".github/vscode:/vscode.git/clone" did not exist on "04ae9b80d4b5676911a9dfbb27805d477b1d889f"
transformer.py 56.1 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Transformer."""
import math
5
from contextlib import nullcontext
6
import torch
7
import torch.nn.functional as F
8

9
from megatron import get_timers, get_args, core, get_num_microbatches
10
from .module import MegatronModule
11
from megatron.core import mpu, tensor_parallel
12
from megatron.core.enums import ModelType
13
from megatron.model import LayerNorm
14
from megatron.model.enums import AttnMaskType, LayerType, AttnType
15
16
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
Mostofa Patwary's avatar
Mostofa Patwary committed
17
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
18
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
19

20
21
22
23
24
25
26
27
28
29
try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None

30
31
32
33
34
35
36
37
38
39
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length (sq / sk: query / key sequence length)
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

45
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode each sample along dim 1 (the batch dim of an
    ``[s, b, h]`` tensor) is either zeroed with probability ``drop_prob``
    or kept and rescaled by ``1 / (1 - drop_prob)``. In eval mode, or when
    ``drop_prob`` is 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply per-sample stochastic depth.

        Args:
            hidden_state: tensor with at least 2 dims; dim 1 is treated as
                the per-sample (batch) dim, i.e. ``[s, b, h]`` in this file.

        Returns:
            Tensor with the same shape as ``hidden_state``.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets:
        # one draw per sample, broadcast over every other dim -> (1, b, 1, ...)
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        if keep_prob == 0.:
            # drop_prob == 1: every path is dropped and random_tensor is all
            # zeros. Dividing by keep_prob would give inf * 0 = NaN, so skip
            # the rescale. Multiplying (rather than zeros_like) keeps the
            # autograd graph connected.
            return hidden_state * random_tensor
        output = hidden_state.div(keep_prob) * random_tensor
        return output

67
68
69
70
71
72
73
74
75
76
77
def _args_to_kwargs():
    """Build the keyword arguments shared by the tensor-parallel linear layers.

    Reads the global Megatron args and returns a dict with the keys
    ``params_dtype``, ``use_cpu_initialization``, ``perform_initialization``,
    ``gradient_accumulation_fusion`` and ``sequence_parallel_enabled``.
    """
    megatron_args = get_args()
    # Keyword names follow the linear-layer API. Note that
    # ``sequence_parallel_enabled`` is sourced from ``args.sequence_parallel``.
    return dict(
        params_dtype=megatron_args.params_dtype,
        use_cpu_initialization=megatron_args.use_cpu_initialization,
        perform_initialization=megatron_args.perform_initialization,
        gradient_accumulation_fusion=megatron_args.gradient_accumulation_fusion,
        sequence_parallel_enabled=megatron_args.sequence_parallel,
    )
78

79
80
81
82
83
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    Projects an ``[s, b, h]`` input up to ``args.ffn_hidden_size`` (sharded
    across tensor-parallel ranks), applies the configured non-linearity and
    projects back down to ``h``. Both linear layers are built with
    ``skip_bias_add=True``, so ``forward`` returns ``(output, output_bias)``
    and leaves the final bias add to the caller. The returned bias may be
    ``None`` (presumably when ``args.add_bias_linear`` is false; confirm
    against RowParallelLinear).
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = args.add_bias_linear
        self.swiglu = args.swiglu

        # SwiGLU (https://arxiv.org/pdf/2002.05202.pdf) needs a gate and a
        # value projection, so the up-projection is twice as wide.
        if args.swiglu:
            up_width = args.ffn_hidden_size * 2
        else:
            up_width = args.ffn_hidden_size
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            up_width,
            bias=self.add_bias,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())

        # Choose the non-linearity; the first matching flag wins. Only the
        # default GeLU path may use the fused bias+GeLU kernel.
        self.bias_gelu_fusion = False
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            def swiglu(x):
                halves = torch.chunk(x, 2, dim=-1)
                return F.silu(halves[0]) * halves[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=self.add_bias,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def forward(self, hidden_states):
        """Run the MLP on ``hidden_states`` (``[s, b, h]``).

        Returns:
            ``(output, output_bias)`` from the down-projection; the bias is
            not added to ``output``.
        """
        # [s, b, h] -> [s, b, 4hp] (twice as wide before swiglu)
        intermediate, up_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            if up_bias is not None:
                intermediate = intermediate + up_bias
            intermediate = self.activation_func(intermediate)
        else:
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            intermediate = bias_gelu_impl(intermediate, up_bias)

        # [s, b, 4hp] -> [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate)
        return output, output_bias
152

rprenger's avatar
rprenger committed
153
154
155
156
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts".

    A linear router scores every token. Each token is sent to its single
    highest-probability expert (``torch.max`` over the softmaxed scores), and
    that expert's output and bias are scaled by the routing probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        # Mirrors ParallelMLP: without linear biases the experts hand back a
        # None bias, so this module must not try to expand/scale it.
        self.add_bias = args.add_bias_linear
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        """Route each token of ``hidden_states`` (``[s, b, h]``) to one expert.

        Returns:
            ``(output, output_bias)``, each ``[s, b, h]``. ``output_bias`` is
            ``None`` when ``args.add_bias_linear`` is false, matching the
            contract of ``ParallelMLP``.
        """
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
        max_ind = max_ind.view(-1)  # [s*b]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states) if self.add_bias else None

        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if output_bias_total is not None:
                output_bias_total[local_indices, :] = output_bias.expand_as(output)

        output_total = (output_total * max_prob).view(s, b, h)
        if output_bias_total is not None:
            output_bias_total = (output_bias_total * max_prob).view(s, b, h)

        return output_total, output_bias_total
200

201
202

class CoreAttention(MegatronModule):
    """Scaled dot-product attention over already-projected q/k/v.

    Inputs are per-partition tensors of shape ``[s, b, np, hn]``. Scores are
    computed with ``alpha = 1 / norm_factor``, masked and softmaxed by
    ``FusedScaleMaskSoftmax``, passed through dropout, and used to weight the
    values. The result is returned as ``[sq, b, hp]``.
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        # Query-key layer scaling forces the softmax to run in fp32.
        self.attention_softmax_in_fp32 = (
            True if self.apply_query_key_layer_scaling
            else args.attention_softmax_in_fp32)
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(
            projection_size, world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # With query-key layer scaling the scores are additionally divided by
        # layer_number, and layer_number is passed to the softmax as coeff.
        head_scale = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor = head_scale * coeff
        else:
            coeff = None
            self.norm_factor = head_scale

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Attend ``query_layer`` over ``key_layer`` / ``value_layer``.

        Args:
            query_layer: ``[sq, b, np, hn]``.
            key_layer: ``[sk, b, np, hn]``.
            value_layer: ``[sk, b, np, hn]``.
            attention_mask: forwarded unchanged to ``scale_mask_softmax``.

        Returns:
            Context tensor of shape ``[sq, b, hp]``.
        """
        # ---- Raw attention scores: [b, np, sq, sk] ----
        q_len = query_layer.size(0)
        batch = query_layer.size(1)
        heads = query_layer.size(2)
        k_len = key_layer.size(0)
        scores_shape = (batch, heads, q_len, k_len)
        merged = batch * heads

        # [s, b, np, hn] -> [s, b * np, hn]
        query_layer = query_layer.view(q_len, merged, -1)
        key_layer = key_layer.view(k_len, merged, -1)

        # Scores land in a preallocated global buffer: [b * np, sq, sk].
        score_buffer = mpu.get_global_memory_buffer().get_tensor(
            (merged, q_len, k_len), query_layer.dtype, "mpu")
        raw_scores = torch.baddbmm(
            score_buffer,
            query_layer.transpose(0, 1),                # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))
        attention_scores = raw_scores.view(*scores_shape)

        # ---- Attention probs and dropout ----
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Without sequence parallelism the dropout runs inside a fork of the
        # tensor-parallel CUDA RNG tracker.
        if self.sequence_parallel:
            rng_context = nullcontext()
        else:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        with rng_context:
            attention_probs = self.attention_dropout(attention_probs)

        # ---- Context layer: [sq, b, hp] ----
        v_batch = value_layer.size(1)
        v_heads = value_layer.size(2)
        v_dim = value_layer.size(3)
        context_shape = (v_batch, v_heads, q_len, v_dim)  # [b, np, sq, hn]

        # [sk, b, np, hn] -> [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), v_batch * v_heads, -1)
        # [b, np, sq, sk] -> [b * np, sq, sk]
        attention_probs = attention_probs.view(v_batch * v_heads, q_len, -1)

        # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        context_layer = context_layer.view(*context_shape).permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] -> [sq, b, hp]
        merged_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        return context_layer.view(*merged_shape)


338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        causal: apply a causal mask (see ``forward`` for the inference case).
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        device, dtype: accepted for interface compatibility; unused.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: query tensor, (B, Sq, H, D).
            k, v: key and value tensors, (B, Sk, H, D). In training Sk must
                equal Sq. In inference Sk may exceed Sq when k/v come from a
                key/value cache that includes previously processed tokens.

        Returns
        -------
            Tensor of shape (B, Sq, H, D).
        """
        assert all(x.dtype in [torch.float16, torch.bfloat16] for x in (q, k, v))
        assert all(x.is_cuda for x in (q, k, v))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        # All sequences in the batch share one length, so the cumulative
        # offsets are plain multiples of that length.
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                    dtype=torch.int32, device=q.device)

        if self.training:
            # During training q, k and v always cover the same positions.
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # With a key/value cache, k/v are longer than q and need their own
            # offsets and max length. Only the step where Sq == Sk (the full
            # prompt) needs the causal mask. Later steps are assumed to feed one
            # new token, which may attend to every cached key.
            is_causal = self.causal and seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                        dtype=torch.int32, device=q.device)
            dropout_p = 0.0

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output


380
class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention (self- or cross-attention).

    Takes ``hidden_states`` of size ``[s, b, h]`` and returns an output of
    the same size together with the output projection's bias, which is
    not added here.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype
        self.sequence_parallel = args.sequence_parallel

        self.use_flash_attn = args.use_flash_attn
        if self.use_flash_attn:
            # Validate the FlashAttention configuration before building layers.
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        def column_parallel(output_size):
            # Input projection sharded along its output dim; each rank keeps
            # its own slice (gather_output=False).
            return tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                output_size,
                bias=args.add_bias_linear,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

        # Strided linear layer(s).
        if attention_type == AttnType.self_attn:
            self.query_key_value = column_parallel(3 * projection_size)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = column_parallel(projection_size)
            self.key_value = column_parallel(2 * projection_size)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=args.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing.

        Runs ``core_attention`` under ``tensor_parallel.checkpoint``. Rotary
        embeddings, if given, are passed to the checkpoint as extra inputs
        but are not used by the recomputed function.
        """
        def custom_forward(*inputs):
            q, k, v, mask = inputs[:4]
            return self.core_attention(q, k, v, mask)

        if rotary_pos_emb is None:
            q_pos_emb, k_pos_emb = None, None
        else:
            q_pos_emb, k_pos_emb = rotary_pos_emb

        return tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized ``[max_seq_len, batch, np, hn]`` cache
        tensor in ``params_dtype`` on the current CUDA device."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        """Compute attention for ``hidden_states`` (``[sq, b, h]``).

        Args:
            hidden_states: ``[sq, b, h]`` input.
            attention_mask: forwarded to the core attention.
            encoder_output: key/value source for cross-attention.
            inference_params: when truthy, keys/values are cached per layer in
                ``inference_params.key_value_memory_dict``.
            rotary_pos_emb: a single embedding (used for both q and k) or a
                ``(q_pos_emb, k_pos_emb)`` tuple.

        Returns:
            ``(output, bias)`` from the output projection.
        """
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            kv_cache = inference_params.key_value_memory_dict
            if self.layer_number in kv_cache:
                inference_key_memory, inference_value_memory = \
                    kv_cache[self.layer_number]
            else:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                kv_cache[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True

        # =====================
        # Query, Key, and Value
        # =====================
        heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head
        if self.attention_type == AttnType.self_attn:
            # [sq, b, h] --> [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)
            mixed_x_layer = mixed_x_layer.view(
                *(mixed_x_layer.size()[:-1] + (heads, 3 * head_dim)))
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            query_layer, key_layer, value_layer = \
                tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # [sk, b, h] --> [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            mixed_kv_layer, _ = self.key_value(encoder_output)
            mixed_kv_layer = mixed_kv_layer.view(
                *(mixed_kv_layer.size()[:-1] + (heads, 2 * head_dim)))
            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            key_layer, value_layer = \
                tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # [sq, b, h] --> [sq, b, hp] --> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)
            query_layer = query_layer.view(
                *(query_layer.size()[:-1] + (heads, head_dim)))

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # A single rotary embedding is shared by queries and keys.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            # Write the new keys/values into the cache, then attend over
            # everything cached so far for this batch slice.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # NOTE: need to cross check this condition during inference.
                if is_first_step:
                    # First pass: use the whole provided prefix. q_pos_emb
                    # covers prefix + to-be-generated output, so slice it
                    # down to the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                else:
                    # Later passes compute one token at a time, so only the
                    # last position's embedding is selected.
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if self.use_flash_attn:
            # FlashSelfAttention takes [b, s, np, hn] inputs.
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            rng_context = (nullcontext() if self.sequence_parallel
                           else tensor_parallel.get_cuda_rng_tracker().fork())
            with rng_context:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
        elif self.checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer, key_layer, value_layer, attention_mask)
        else:
            context_layer = self.core_attention(
                query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)
        return output, bias


656
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    ``bias`` may be None, in which case it is not added. Dropout uses
    probability ``prob`` and is only active when ``training`` is True.
    """
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Return an unfused bias+dropout+add callable with ``training`` bound.

    The returned function takes ``(x, bias, residual, prob)``, matching the
    signature of the TorchScript-compiled variants below.
    """
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


# The scripted variants declare their types with Python 2-style comments.
# TorchScript resolves `Tensor` and `Optional` in those comments without
# importing `typing`. `bias` must be Optional: callers pass None when the
# preceding linear layer produced no bias.
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Optional[Tensor], Tensor, float) -> Tensor
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Optional[Tensor], Tensor, float) -> Tensor
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, False)
685
686
687
688
689


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Sub-layers, in order: input layernorm, self attention, post-attention
    layernorm, (decoder layers only) cross attention followed by its own
    layernorm, and the MLP. Every attention/MLP output goes through a
    bias + dropout + residual-add step. When ``drop_path_rate > 0``, the
    self-attention and MLP residual branches also pass through ``DropPath``
    (stochastic depth).
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the cross-attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # From torch 1.10 on, the fused call runs under a no-op context;
        # older versions wrap it in torch.enable_grad.
        torch_major = int(torch.__version__.split('.')[0])
        torch_minor = int(torch.__version__.split('.')[1])
        use_nvfuser = torch_major > 1 or (torch_major == 1 and torch_minor >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        """Run the layer.

        Args:
            hidden_states: input activations, [s, b, h].
            attention_mask: mask passed to self attention.
            encoder_output: encoder activations for cross attention
                (decoder layers only).
            enc_dec_attn_mask: mask passed to cross attention (decoder
                layers only).
            inference_params: forwarded to self attention.
            rotary_pos_emb: rotary embeddings forwarded to self attention.

        Returns:
            Output activations of the same size as ``hidden_states``.
        """
        # hidden_states: [s, b, h]

        # Pick the bias+dropout+add implementation up front. The decoder's
        # cross-attention branch below needs it even when drop-path is
        # enabled, so it must not be bound only inside the drop-path branch.
        # jit scripting for a nn.module (with dropout) is not triggering the
        # fusion kernel. For now, we use two different nn.functional
        # routines to account for varying dropout semantics during training
        # and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            # The attention projection may produce no bias; mirror the MLP
            # branch and only add it when present.
            if attention_bias is not None:
                attention_output = attention_output + attention_bias
            out = torch.nn.functional.dropout(attention_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


879
880
881
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None, **kwargs):
        """Return a copy of ``hidden_states``; all other arguments are ignored.

        The signature mirrors ``ParallelTransformerLayer.forward`` (including
        ``rotary_pos_emb``). It also accepts extra keyword arguments, such as
        the ``is_first_microbatch`` that ParallelTransformer's Transformer
        Engine checkpointing path passes to every layer. This lets the no-op
        layer stand in for either layer implementation.
        """
        return hidden_states.clone()


Jared Casper's avatar
Jared Casper committed
905
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Without pipeline parallelism, the rank holds the whole encoder stack
    (or the whole decoder stack when ``is_decoder`` is True).

    With pipeline parallelism, layers are divided evenly among the pipeline
    ranks that hold them:

    * Encoder-decoder models split the ranks at
      ``args.pipeline_model_parallel_split_rank``. Ranks before the split
      hold encoder layers and the rest hold decoder layers. With a
      standalone embedding stage, one of the pre-split ranks is given to
      the embedding.
    * Other models divide ``args.num_layers`` across
      ``args.transformer_pipeline_model_parallel_size`` ranks.

    With ``args.standalone_embedding_stage``, pipeline rank 0 gets zero
    transformer layers.

    Args:
        args: global Megatron arguments.
        is_encoder_and_decoder_model: True for models with separate encoder
            and decoder stacks.
        is_decoder: only consulted without pipeline parallelism; selects
            ``args.decoder_num_layers`` instead of ``args.encoder_num_layers``.

    Returns:
        int: number of layers on this rank.
    """
    def _is_standalone_embedding_rank():
        # Pipeline rank 0 only hosts the input embedding when a standalone
        # embedding stage is enabled. The rank is looked up lazily, only
        # if that option is set.
        return (args.standalone_embedding_stage
                and mpu.get_pipeline_model_parallel_rank() == 0)

    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        # No pipeline parallelism: this rank owns its entire stack.
        return args.decoder_num_layers if is_decoder else args.encoder_num_layers

    if not is_encoder_and_decoder_model:
        assert args.num_layers == args.encoder_num_layers
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # With a standalone embedding stage, all transformer layers are
        # spread over pipeline ranks >= 1. Pipeline rank 0 then holds either
        # the input embedding (virtual pp rank 0) or nothing at all (virtual
        # pp rank >= 1).
        if _is_standalone_embedding_rank():
            return 0
        return args.num_layers // args.transformer_pipeline_model_parallel_size

    assert args.pipeline_model_parallel_split_rank is not None

    # A standalone embedding stage borrows one of the encoder's ranks for the
    # encoder embedding. This keeps the rank named by the 'split rank' the
    # same whether or not the standalone stage is used.
    split_rank = args.pipeline_model_parallel_split_rank
    if args.standalone_embedding_stage:
        num_ranks_in_encoder = split_rank - 1
    else:
        num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder

    assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
            'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
    assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
            'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)

    if not mpu.is_pipeline_stage_before_split():
        return args.decoder_num_layers // num_ranks_in_decoder
    if _is_standalone_embedding_rank():
        return 0
    return args.encoder_num_layers // num_ranks_in_encoder


957
958
959
class ParallelTransformer(MegatronModule):
    """Transformer class."""

960
    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the transformer layers that live on this pipeline rank.

        Args:
            init_method: initializer passed to every layer's ``init_method``.
            output_layer_init_method: initializer passed to every layer's
                ``output_layer_init_method``.
            layer_type: ``LayerType`` of this stack (encoder or decoder).
            self_attn_mask_type: ``AttnMaskType`` used by self attention.
            post_layer_norm: together with ``post_process``, enables the
                final LayerNorm.
            pre_process: when False, ``forward`` reads its input from
                ``self.input_tensor`` (see ``set_input_tensor``) instead of
                its ``hidden_states`` argument.
            post_process: together with ``post_layer_norm``, enables the
                final LayerNorm.
            drop_path_rate: largest stochastic-depth rate. Per-layer rates
                ramp linearly from 0 to it over ``args.num_layers`` layers.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        # Populated by set_input_tensor(); read by forward() when
        # pre_process is False.
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl

        # Activation recomputation (checkpointing) settings.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Saved activations are never distributed under sequence parallelism.
        self.distribute_saved_activations = (
            args.distribute_saved_activations and not args.sequence_parallel)
        self.sequence_parallel = args.sequence_parallel

        # Import Transformer Engine only when it is the selected
        # implementation. The `global` makes the module visible to
        # build_layer() below and to the other methods of this class.
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine

        # FP8 recipe.
        # NOTE(review): this references transformer_engine, which is only
        # imported above for the 'transformer_engine' impl. Presumably
        # argument validation forbids fp8 with the local impl; confirm.
        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
        self.fp8_recipe = None
        self.fp8_group = mpu.get_data_parallel_group()
        if self.use_fp8:
            recipe = transformer_engine.common.recipe
            fp8_format = recipe.Format.E4M3 if args.fp8_e4m3 else recipe.Format.HYBRID
            self.fp8_recipe = recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Number of layers resident on this rank.
        self.num_layers = _get_num_layers(
            args,
            args.model_type == ModelType.encoder_and_decoder,
            layer_type == LayerType.decoder)

        # One stochastic-depth rate per layer of the whole model, ramping
        # linearly from 0 to drop_path_rate. build_layer() indexes it by the
        # 1-based global layer number.
        self.drop_path_rates = [
            rate.item()
            for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        def build_layer(layer_number):
            """Construct the layer with 1-based global number ``layer_number``."""
            layer_drop_path_rate = self.drop_path_rates[layer_number - 1]
            if args.transformer_impl == 'local':
                return ParallelTransformerLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=layer_drop_path_rate)
            return transformer_engine.pytorch.TransformerLayer(
                args.hidden_size,
                args.ffn_hidden_size,
                args.num_attention_heads,
                layernorm_epsilon=args.layernorm_epsilon,
                hidden_dropout=args.hidden_dropout,
                attention_dropout=args.attention_dropout,
                init_method=init_method,
                output_layer_init_method=output_layer_init_method,
                layer_number=layer_number,
                kv_channels=args.kv_channels,
                self_attn_mask_type=self_attn_mask_type.name,
                tp_group=mpu.get_tensor_model_parallel_group(),
                get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
                apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
                attention_softmax_in_fp32=args.attention_softmax_in_fp32,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                sequence_parallel=args.sequence_parallel,
                params_dtype=args.params_dtype,
                apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
                output_layernorm=False,
                layer_type="encoder",
                drop_path_rate=layer_drop_path_rate,
                set_parallel_mode=True,
                fuse_qkv_params=True)

        # Global layer-number offset of this rank's first layer.
        # NOTE(review): with args.standalone_embedding_stage, _get_num_layers
        # gives pipeline rank 0 no layers, yet the offsets below still count
        # rank 0 as a full stage. Layer numbers on later ranks therefore
        # appear able to exceed args.num_layers and index past
        # self.drop_path_rates. Confirm.
        vpp_size = args.virtual_pipeline_model_parallel_size
        if vpp_size is not None:
            assert args.num_layers % vpp_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Each model chunk gets this stage's layers divided by the number
            # of model chunks in a stage.
            self.num_layers = self.num_layers // vpp_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            layers_per_virtual_stage = args.num_layers // vpp_size
            offset = (mpu.get_virtual_pipeline_model_parallel_rank() * layers_per_virtual_stage
                      + mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        elif (args.model_type == ModelType.encoder_and_decoder
              and mpu.get_pipeline_model_parallel_world_size() > 1):
            # Each stage gets a contiguous set of layers. Decoder layers are
            # numbered from the first rank after the split.
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
            if layer_type == LayerType.encoder:
                offset = pipeline_rank * self.num_layers
            else:
                first_decoder_rank = args.pipeline_model_parallel_split_rank
                offset = (pipeline_rank - first_decoder_rank) * self.num_layers
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # With a standalone embedding stage, virtual pipeline ranks on
            # pipeline rank 0 have no transformer layers. Their input and
            # output tensors would then be the same object, which breaks
            # output-tensor optimizations such as pipeline output
            # deallocation. A single no-op layer disconnects the two.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(layer_number)
                 for layer_number in range(offset + 1, offset + self.num_layers + 1)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)
1116

Mohammad's avatar
Mohammad committed
1117
    def _get_layer(self, layer_number):
        """Return the locally held layer at 0-based position ``layer_number``."""
        layer = self.layers[layer_number]
        return layer
Mohammad's avatar
Mohammad committed
1119

1120
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing.

        Runs every local layer on ``hidden_states``, recomputing activations
        in the backward pass according to ``self.recompute_method``:

        * ``'uniform'``: layers are split into consecutive chunks of
          ``self.recompute_num_layers`` and only each chunk's input is
          saved. The last chunk is shortened when ``self.num_layers`` is not
          a multiple of the chunk size.
        * ``'block'``: the first ``self.recompute_num_layers`` layers are
          checkpointed individually and the remaining layers run normally.

        Returns:
            The hidden states after the last local layer.

        Raises:
            ValueError: if ``self.recompute_method`` is neither of the above.
        """
        use_te = self.transformer_impl == 'transformer_engine'
        # Arguments forwarded positionally to each layer after the hidden states.
        layer_args = (attention_mask, encoder_output, enc_dec_attn_mask, rotary_pos_emb)

        def custom(start, end, is_transformer_engine=False):
            """Return a callable that runs local layers [start, end) in order."""
            def custom_forward(*inputs, **kwargs):
                x_, *rest = inputs
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *rest, **kwargs)
                return x_

            def custom_forward_transformer_engine(*inputs, **kwargs):
                # Transformer Engine layers additionally take is_first_microbatch.
                return custom_forward(*inputs, is_first_microbatch=is_first_microbatch, **kwargs)

            if is_transformer_engine:
                return custom_forward_transformer_engine
            return custom_forward

        def checkpointed(start, end, hidden):
            """Run layers [start, end), saving only their input for recompute."""
            if use_te:
                return transformer_engine.pytorch.distributed.checkpoint(
                    custom(start, end, is_transformer_engine=True),
                    self.distribute_saved_activations,
                    tensor_parallel.get_cuda_rng_tracker,
                    mpu.get_tensor_model_parallel_group(),
                    hidden, *layer_args)
            return tensor_parallel.checkpoint(
                custom(start, end),
                self.distribute_saved_activations,
                hidden, *layer_args)

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            start = 0
            while start < self.num_layers:
                # Clamp so a trailing partial chunk (e.g. a single no-op
                # layer with recompute_num_layers > 1) never indexes past
                # the last local layer.
                end = min(start + self.recompute_num_layers, self.num_layers)
                hidden_states = checkpointed(start, end, hidden_states)
                start += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of
            # individual Transformer layers and skip the rest.
            # A method to fully use the device memory, removing redundant
            # re-computation.
            for index in range(self.num_layers):
                if index < self.recompute_num_layers:
                    hidden_states = checkpointed(index, index + 1, hidden_states)
                else:
                    hidden_states = custom(index, index + 1, is_transformer_engine=use_te)(
                        hidden_states, *layer_args)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

1195
    def set_input_tensor(self, input_tensor):
        """Store a tensor that replaces forward()'s ``hidden_states`` input.

        Under pipeline parallelism, a stage's input comes from the previous
        stage via communication rather than from the model's
        forward_step_func. Internal code therefore calls this method to hand
        the tensor over. ``forward`` reads it when ``pre_process`` is False.
        """
        self.input_tensor = input_tensor

1205
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        """Run the input through every transformer layer in this stack.

        Args:
            hidden_states: input activations ([s, b, h], per the notation at
                the top of the file). Ignored when ``self.pre_process`` is
                False; the tensor registered via ``set_input_tensor()`` is
                used instead.
            attention_mask: mask forwarded to every layer.
            encoder_output: optional encoder output forwarded to every layer.
            enc_dec_attn_mask: optional encoder-decoder attention mask
                forwarded to every layer.
            inference_params: optional inference state forwarded to every
                layer. Must not be combined with activation recompute.
            rotary_pos_emb: optional rotary positional embedding forwarded
                to every layer.

        Returns:
            The output of the last layer. It is passed through
            ``self.final_layernorm`` when both ``post_process`` and
            ``post_layer_norm`` are set.
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params is not None:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - hidden_states is always converted to a viewless tensor. The
        #   original motivation was that an upstream transpose followed by
        #   '.contiguous()' can yield a view when the micro batch size is 1.
        #   NOTE(review): that transpose is no longer in this method;
        #   confirm where views can still originate.
        # - make_viewless_tensor() has negligible overhead when its input is
        #   already viewless, so no explicit check is done here.
        # - When hidden_states comes from set_input_tensor(), it was most
        #   likely produced by p2p_communication.py, which already creates
        #   viewless tensors. The call is kept anyway to be future-proof and
        #   corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # fp8_autocast is a no-op when enabled=False. The conditional
            # expression also short-circuits name resolution of
            # `transformer_engine`, so the name is only needed when fp8 is on.
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine whether the current iteration is the first
                # microbatch. Query the microbatch count once so every use
                # below sees the same value.
                num_microbatches = get_num_microbatches()
                if self.num_microbatches_in_previous_step != num_microbatches:
                    # Reset the count on a new batch-size rampup interval.
                    self.microbatch_count = 0
                self.num_microbatches_in_previous_step = num_microbatches
                is_first_microbatch = \
                    self.microbatch_count % num_microbatches == 0

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                        'rotary_pos_emb': rotary_pos_emb,
                    }

                    # Extra arguments accepted only by the Transformer Engine
                    # layer implementation.
                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states