transformer.py 73.9 KB
Newer Older
liangjing's avatar
v1  
liangjing committed
1
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
3

"""Transformer."""
4
from contextlib import nullcontext
liangjing's avatar
v1  
liangjing committed
5
6
import math
import numpy as np
7
import torch
8
import torch.nn.functional as F
Jared Casper's avatar
Jared Casper committed
9
from typing import Optional
10

liangjing's avatar
v1  
liangjing committed
11
from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
12
from .module import MegatronModule
13
from megatron.core import mpu, tensor_parallel
14
from megatron.core.enums import ModelType
15
from megatron.model import LayerNorm
16
from megatron.model.enums import AttnMaskType, LayerType, AttnType
17
18
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
liangjing's avatar
v1  
liangjing committed
19
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
20
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
21

22
23
24
25
26
27
28
29
try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
liangjing's avatar
v1  
liangjing committed
30
31
32
33
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None
34

35
36
37
38
39
40
41
42
43
44
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

50
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    While training, a binary keep/drop mask is drawn with one entry per
    index of dimension 1 of the input and broadcast over all other
    dimensions. Kept entries are scaled by ``1 / keep_prob``.
    """

    def __init__(self, drop_prob=0.):
        """Store the drop probability.

        Args:
            drop_prob: probability of dropping a path; ``0.`` disables
                the module.
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply stochastic depth to ``hidden_state``.

        Args:
            hidden_state: input tensor, documented upstream as [s, b, h].

        Returns:
            ``hidden_state`` itself when ``drop_prob`` is 0 or the module
            is not training; otherwise the masked and rescaled tensor.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state

        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h]
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize

        # Only rescale when something can be kept: with keep_prob == 0 the
        # division would produce inf/NaN, and inf * 0 is NaN, not 0.
        if keep_prob > 0.:
            hidden_state = hidden_state.div(keep_prob)
        return hidden_state * random_tensor

72
73
74
75
76
class ParallelMLP(MegatronModule):
    """MLP.

    Projects the input from the hidden size up to the feed-forward size
    (nominally 4*h), applies a nonlinearity, and projects the result back
    down to the hidden size.
    """

    def __init__(self, config):
        """Build the two projections and select the activation function.

        Args:
            config: transformer config; this constructor reads
                ``add_bias_linear``, ``ffn_hidden_size``,
                ``gated_linear_unit``, ``hidden_size``, ``init_method``
                and ``output_layer_init_method`` from it.
        """
        super().__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        # A gated linear unit needs twice the width on the up-projection:
        # the activation splits it into two halves
        # (see https://arxiv.org/pdf/2002.05202.pdf).
        up_projection_width = (config.ffn_hidden_size * 2
                               if config.gated_linear_unit
                               else config.ffn_hidden_size)

        # Project to 4h.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            up_projection_width,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
        )

        self.swiglu = args.swiglu
        self.activation_func, self.bias_gelu_fusion = \
            self._pick_activation(args)

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            input_is_parallel=True
        )

    @staticmethod
    def _pick_activation(args):
        """Return ``(activation_func, bias_gelu_fusion)`` for ``args``.

        Flags are checked in priority order: openai_gelu, onnx_safe,
        swiglu, squared_relu. Only the default GELU path may enable the
        fused bias+GELU kernel.
        """
        if args.openai_gelu:
            return openai_gelu, False
        if args.onnx_safe:
            return erf_gelu, False
        if args.swiglu:
            def swiglu(x):
                halves = torch.chunk(x, 2, dim=-1)
                return F.silu(halves[0]) * halves[1]
            return swiglu, False
        if args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            return squared_relu, False
        return F.gelu, args.bias_gelu_fusion

    def forward(self, hidden_states):
        """Run the MLP; returns ``(output, output_bias)`` from the
        down-projection."""
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            # Up-projection skipped the bias add, so apply it here
            # before the activation.
            if projected_bias is not None:
                projected = projected + projected_bias
            activated = self.activation_func(projected)
        else:
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            activated = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias
149

rprenger's avatar
rprenger committed
150
151
152
153
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    A linear router scores every token, the top-scoring expert processes
    it, and the expert output is scaled by the router probability.
    """

    def __init__(self, config):
        """Create the router and ``args.num_experts`` ParallelMLP experts.

        Args:
            config: transformer config; ``hidden_size`` is read here and
                the whole object is forwarded to each expert.
        """
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(config.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(config))

    def forward(self, hidden_states):
        """Route each token to its top-1 expert.

        Args:
            hidden_states: tensor of shape [s, b, h].

        Returns:
            ``(output, output_bias)``, each [s, b, h] and scaled by the
            router probability. ``output_bias`` is ``None`` when no
            expert returned a bias.
        """
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
        max_ind = max_ind.view(-1)  # [s*b]

        output_total = torch.empty_like(hidden_states)
        # Allocated on first use so that no [s*b, h] buffer is created
        # when the experts do not return a bias.
        output_bias_total = None

        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            if output_bias is not None:
                if output_bias_total is None:
                    output_bias_total = torch.empty_like(hidden_states)
                output_bias_total[local_indices, :] = \
                    output_bias.expand_as(output)
            output_total[local_indices, :] = output

        output_total = output_total * max_prob
        output_total = output_total.view(s, b, h)
        # Decide on the buffer itself rather than on whatever value the
        # loop variable happened to hold after the last iteration.
        if output_bias_total is not None:
            output_bias_total = output_bias_total * max_prob
            output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total
201

202
203

class CoreAttention(MegatronModule):
    """Unfused scaled dot-product attention on already-projected
    query/key/value tensors: scores -> masked softmax -> dropout ->
    weighted sum of values."""

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        """Derive per-partition sizes and build softmax and dropout.

        Args:
            layer_number: layer index; clamped to at least 1.
            config: transformer config (precision flags, scaling flags,
                head counts, dropout rate are read here).
            attn_mask_type: mask type handed to the fused softmax.
        """
        super().__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # Layer scaling forces the softmax to run in fp32.
        self.attention_softmax_in_fp32 = (
            True if self.apply_query_key_layer_scaling
            else config.attention_softmax_in_fp32)
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(
            projection_size, world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        head_scale = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            # Scores are divided by an extra layer_number factor; the
            # same coefficient is handed to the softmax module.
            coeff = self.layer_number
            self.norm_factor = head_scale * coeff
        else:
            coeff = None
            self.norm_factor = head_scale

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute attention.

        Args:
            query_layer: [sq, b, np, hn]
            key_layer: [sk, b, np, hn]
            value_layer: [sk, b, np, hn]
            attention_mask: mask passed through to the fused softmax.

        Returns:
            Context tensor of shape [sq, b, hp].
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================
        batch = query_layer.size(1)
        heads = query_layer.size(2)
        sq = query_layer.size(0)
        sk = key_layer.size(0)

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(sq, batch * heads, -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(sk, batch * heads, -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (batch * heads, sq, sk),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the buffer contents are ignored; it only
        # supplies the output storage shape.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(batch, heads, sq, sk)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if self.sequence_parallel:
            attention_probs = self.attention_dropout(attention_probs)
        else:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        value_batch = value_layer.size(1)
        value_heads = value_layer.size(2)
        value_head_dim = value_layer.size(3)

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       value_batch * value_heads, -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(value_batch * value_heads,
                                               sq, -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(value_batch, value_heads,
                                           sq, value_head_dim)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        merged_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        return context_layer.view(*merged_shape)


338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        causal: Whether to apply a causal mask. (default: False)
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        device, dtype: accepted for interface compatibility; not used.
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)

        Returns the attention output rearranged back to a leading
        (B, S) layout. Inputs must be fp16/bf16 CUDA tensors.
        """

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
        assert all((i.is_cuda for i in (q,k,v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        # Flatten batch and sequence; the kernel takes packed sequences
        # delimited by cumulative sequence-length offsets.
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            # The configured flag is still honoured: a module built with
            # causal=False must never apply a causal mask.
            is_causal = self.causal and seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                        device=q.device)
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )

        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output


400
class ParallelAttention(MegatronModule):
401
402
    """Parallel self-attention layer abstract class.

Vijay Korthikanti's avatar
Vijay Korthikanti committed
403
    Self-attention layer takes input with size [s, b, h]
404
405
    and returns output of the same size.
    """
Neel Kant's avatar
Neel Kant committed
406

liangjing's avatar
v1  
liangjing committed
407
    def __init__(self, config, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        """Build the projections and the core attention module.

        Args:
            config: transformer config object.
            layer_number: layer index; clamped to at least 1.
            attention_type: self- or cross-attention; selects either a
                fused QKV projection or separate Q and KV projections.
            attn_mask_type: mask type forwarded to the core attention.

        Raises:
            ImportError: flash attention requested but flash-attn or
                einops is not importable.
            NotImplementedError: grouped-query attention with a group
                count not divisible by the tensor-parallel size, or
                combined with cross-attention.
        """
        super(ParallelAttention, self).__init__()
        args = get_args()

        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = config.params_dtype
        self.sequence_parallel = config.sequence_parallel

        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups

        # With grouped-query attention, K/V get one head per query group
        # instead of one per attention head.
        query_projection_size = config.kv_channels * config.num_attention_heads
        kv_head_count = (args.num_query_groups if self.group_query_attention
                         else args.num_attention_heads)
        kv_projection_size = args.kv_channels * kv_head_count

        self.use_flash_attn = (args.use_flash_attn
                               and attention_type == AttnType.self_attn
                               and self.attn_mask_type == AttnMaskType.causal)
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            query_projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        if not self.group_query_attention:
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition
        else:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError('Currently the num_query_groups should be '
                                          'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = core.utils.divide(
                args.num_query_groups, world_size)

        # Keyword arguments shared by every column-parallel projection.
        column_kwargs = {
            'config': config,
            'init_method': config.init_method,
            'gather_output': False,
        }

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size + 2 * kv_projection_size,
                bias=args.add_bias_linear,
                **column_kwargs)
        else:
            assert attention_type == AttnType.cross_attn

            if self.group_query_attention:
                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
            assert query_projection_size == kv_projection_size

            self.query = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size,
                bias=config.add_bias_linear,
                **column_kwargs)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                2 * kv_projection_size,
                bias=config.add_bias_linear,
                **column_kwargs)

        self.core_attention = CoreAttention(self.layer_number, config,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            query_projection_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True)
Vijay Korthikanti's avatar
Vijay Korthikanti committed
507

508
    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing.

        Runs ``self.core_attention`` through ``tensor_parallel.checkpoint``.
        The rotary embeddings are handed to the checkpoint call as extra
        inputs, but the wrapped function only reads the first four.
        """
        def custom_forward(*inputs):
            q_in = inputs[0]
            k_in = inputs[1]
            v_in = inputs[2]
            mask_in = inputs[3]
            return self.core_attention(q_in, k_in, v_in, mask_in)

        if rotary_pos_emb is None:
            q_pos_emb, k_pos_emb = None, None
        else:
            q_pos_emb, k_pos_emb = rotary_pos_emb

        checkpoint_inputs = (query_layer, key_layer, value_layer,
                             attention_mask, q_pos_emb, k_pos_emb)
        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, *checkpoint_inputs)

        return hidden_states
530

liangjing's avatar
v1  
liangjing committed
531
    def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
        """Return an uninitialized buffer of shape
        [inference_max_sequence_len, batch_size, num_attention_heads, hn]
        in ``self.params_dtype`` on the current CUDA device."""
        buffer_shape = (
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
        )
        return torch.empty(
            buffer_shape,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
Mostofa Patwary's avatar
Mostofa Patwary committed
541
542
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
543
        # hidden_states: [sq, b, h]
544

545
546
547
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
Mostofa Patwary's avatar
Mostofa Patwary committed
548
        is_first_step = False
mshoeybi's avatar
mshoeybi committed
549
        if inference_params:
550
            if self.layer_number not in inference_params.key_value_memory_dict:
liangjing's avatar
v1  
liangjing committed
551
                inf_max_seq_len = inference_params.max_sequence_length
mshoeybi's avatar
mshoeybi committed
552
                inf_max_batch_size = inference_params.max_batch_size
553
                inference_key_memory = self._allocate_memory(
liangjing's avatar
v1  
liangjing committed
554
555
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
556
                inference_value_memory = self._allocate_memory(
liangjing's avatar
v1  
liangjing committed
557
558
559
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

560
561
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
Mostofa Patwary's avatar
Mostofa Patwary committed
562
                is_first_step = True
563
564
565
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]
mshoeybi's avatar
mshoeybi committed
566

567
568
569
        # =====================
        # Query, Key, and Value
        # =====================
570
        if self.attention_type == AttnType.self_attn:
liangjing's avatar
v1  
liangjing committed
571
            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
572
573
            mixed_x_layer, _ = self.query_key_value(hidden_states)

liangjing's avatar
v1  
liangjing committed
574
575
576
577
578
579
580
581
            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
582
583
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

liangjing's avatar
v1  
liangjing committed
584
            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
585
            (query_layer,
liangjing's avatar
v1  
liangjing committed
586
587
588
589
590
591
592
593
594
595
596
597
598
599
            key_layer,
            value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)
            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
600
601
602
603
604
605
606
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
liangjing's avatar
v1  
liangjing committed
607
                2 * self.hidden_size_per_attention_head)
608
609
610
611
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
liangjing's avatar
v1  
liangjing committed
612
            value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
613
614
615
616
617
618

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
liangjing's avatar
v1  
liangjing committed
619
                self.hidden_size_per_attention_head)
620
            query_layer = query_layer.view(*new_tensor_shape)
621

mshoeybi's avatar
mshoeybi committed
622
623
624
        # ==================================
        # Adjust key and value for inference
        # ==================================
625

Mostofa Patwary's avatar
Mostofa Patwary committed
626
627
        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
Mostofa Patwary's avatar
Mostofa Patwary committed
628
629
630
631
            if isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)
Mostofa Patwary's avatar
Mostofa Patwary committed
632

mshoeybi's avatar
mshoeybi committed
633
        if inference_params:
mshoeybi's avatar
mshoeybi committed
634
635
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
636
            assert batch_end <= inference_key_memory.size(1)
mshoeybi's avatar
mshoeybi committed
637
638
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
639
            assert sequence_end <= inference_key_memory.size(0)
640
            # Copy key and values.
641
642
643
644
645
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
mshoeybi's avatar
mshoeybi committed
646
                :sequence_end, batch_start:batch_end, ...]
647
            value_layer = inference_value_memory[
mshoeybi's avatar
mshoeybi committed
648
                :sequence_end, batch_start:batch_end, ...]
649

Mostofa Patwary's avatar
Mostofa Patwary committed
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

671
672
673
        # ==================================
        # core attention computation
        # ==================================
674

liangjing's avatar
v1  
liangjing committed
675
676
677
678
679
680
681
682
683
684
        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        key_layer = key_layer.repeat_interleave(
            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
            dim = 2
        )
        value_layer = value_layer.repeat_interleave(
            self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
            dim = 2
        )

Mostofa Patwary's avatar
Mostofa Patwary committed
685
686
687
688
689
690
691
692
693
694
        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

695
696
697
698
699
700
701
        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
Vijay Korthikanti's avatar
Vijay Korthikanti committed
702
        else:
703
704
705
706
707
708
709
710
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
711
712

        # =================
713
        # Output. [sq, b, h]
714
715
716
        # =================

        output, bias = self.dense(context_layer)
717

718
719
720
        return output, bias


721
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    """Add an optional bias to `x`, apply dropout, then add the residual.

    The type comment above is kept in comment form because this function is
    called from the `torch.jit.script`-decorated wrappers in this file.

    Args:
        x: Tensor the bias and dropout are applied to.
        bias: Optional tensor added to `x` before dropout; skipped when None.
        residual: Tensor added to the dropout result.
        prob: Dropout probability.
        training: Whether dropout is active (True) or an identity (False).

    Returns:
        `residual + dropout(x + bias)` (or `residual + dropout(x)` when
        `bias` is None).
    """
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Return a bias-dropout-add callable with `training` pre-bound.

    The returned function takes `(x, bias, residual, prob)` and forwards them,
    together with the captured `training` flag, to `bias_dropout_add`.
    """
    def _bound_bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bound_bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled `bias_dropout_add` with the training flag set to True."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: Optional[torch.Tensor],
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled `bias_dropout_add` with the training flag set to False."""
    return bias_dropout_add(x, bias, residual, prob, False)
750
751
752
753
754


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
Neel Kant's avatar
Neel Kant committed
758

liangjing's avatar
v1  
liangjing committed
759
    def __init__(self, config,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        """Build the sub-modules of one transformer layer.

        Args:
            config: Transformer configuration object; hidden size, layernorm
                epsilon, dropout and fusion flags are read from it.
            layer_number: Index of this layer, forwarded to the attention
                modules.
            layer_type: `LayerType` member selecting whether a cross-attention
                block (and, for `retro_decoder_with_retriever`, a retriever
                sub-transformer) is created.
            self_attn_mask_type: `AttnMaskType` used by the self-attention.
            drop_path_rate: When > 0, a `DropPath` module is created and used
                instead of the bias-dropout-add path in `forward`.
        """
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = config.apply_residual_connection_post_layernorm

        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection

        # Layernorm on the input data.
        # NOTE(review): this layernorm reads `args.no_persist_layer_norm`
        # while the other layernorms below read `config.persist_layer_norm`
        # -- confirm both always agree.
        self.input_layernorm = LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        # Self attention.
        self.self_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = config.hidden_dropout
        self.bias_dropout_fusion = config.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        # Cross attention.
        if self.layer_type in (LayerType.decoder,
                               LayerType.retro_decoder,
                               LayerType.retro_decoder_with_retriever,
                               LayerType.retro_encoder):
            self.inter_attention = ParallelAttention(
                config,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                config.hidden_size,
                eps=config.layernorm_epsilon,
                no_persist_layer_norm=not config.persist_layer_norm,
                sequence_parallel=config.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(config)
        else:
            self.mlp = ParallelMLP(config)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # For torch >= 1.10 a no-op context is used; older versions wrap the
        # fused call in torch.enable_grad.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

        # The retro_* attributes only exist when the retriever is enabled;
        # the retro cross-attention methods read them.
        if args.retro_add_retriever:
            retro_args = get_retro_args()
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = retro_args.retro_gpt_chunk_length
            self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length

        # Retriever (bi-directional transformer with cross attention)
        if layer_type == LayerType.retro_decoder_with_retriever:
            self.retriever = ParallelTransformer(
                config=config,
                model_type=ModelType.retro_encoder,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=True,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

    def default_decoder_cross_attention(self,
                                        encoder_output,
                                        enc_dec_attn_mask,
                                        layernorm_input,
                                        layernorm_output,
                                        bias_dropout_add_func):
        '''Cross attention for a standard encoder-decoder model.

        Runs `inter_attention` on the post-layernorm decoder activations
        against `encoder_output`, applies bias-dropout-add with the selected
        residual, and layer-norms the result.

        Returns:
            Tuple `(layernorm_input, layernorm_output)`: the bias-dropout-add
            result and its `post_inter_attention_layernorm` output.
        '''

        # Decoder activations attend over the encoder output.
        cross_attn_out, cross_attn_bias = self.inter_attention(
            layernorm_output,
            enc_dec_attn_mask,
            encoder_output=encoder_output)

        # Pick the skip connection: post-layernorm or pre-layernorm tensor.
        skip = (layernorm_output
                if self.apply_residual_connection_post_layernorm
                else layernorm_input)

        if cross_attn_bias is not None:
            cross_attn_bias = cross_attn_bias.expand_as(skip)

        # Bias + dropout + residual add.
        with self.bias_dropout_add_exec_handler():
            pre_norm = bias_dropout_add_func(
                cross_attn_out,
                cross_attn_bias,
                skip,
                self.hidden_dropout)

        # Layer norm on the cross-attention result.
        post_norm = self.post_inter_attention_layernorm(pre_norm)
        return pre_norm, post_norm

    def retro_encoder_cross_attention(self,
                                      retriever_output,
                                      layernorm_input,
                                      layernorm_output,
                                      bias_dropout_add_func):
        """Cross attention for Retro encoder.

        Each neighbor slice of the chunked activations is attended against
        `retriever_output` separately; the per-neighbor results are then
        stacked and reshaped back to the input shape.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).

        Returns:
            Tuple `(layernorm_input, layernorm_output)`, both reshaped to
            the shape of the incoming `layernorm_output`.
        """

        ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]

        # Split the batch dimension so neighbors get their own axis:
        # [r, bs*l, k, d] for both the post- and pre-layernorm tensors.
        post_ln_chunks = layernorm_output.reshape(
            self.retro_retrieved_length, -1, self.retro_num_neighbors, d)
        pre_ln_chunks = layernorm_input.reshape(
            self.retro_retrieved_length, -1, self.retro_num_neighbors, d)

        residual_is_post_ln = self.apply_residual_connection_post_layernorm

        # One attention pass per neighbor.
        per_neighbor_inputs = []
        per_neighbor_outputs = []
        for neighbor in range(self.retro_num_neighbors):

            # Q is the neighbor embedding; K, V come from the hidden
            # activations passed in as `retriever_output`.
            query = post_ln_chunks[:, :, neighbor].contiguous()
            attention_output, attention_bias = self.inter_attention(
                query, None, encoder_output=retriever_output)

            # Residual connection.
            if residual_is_post_ln:
                residual = query
            else:
                residual = pre_ln_chunks[:, :, neighbor]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                if attention_bias is None:
                    expanded_bias = None
                else:
                    expanded_bias = attention_bias.expand_as(residual)
                summed = bias_dropout_add_func(
                    attention_output,
                    expanded_bias,
                    residual,
                    self.hidden_dropout)
            per_neighbor_inputs.append(summed)

            # Layer norm.
            per_neighbor_outputs.append(
                self.post_inter_attention_layernorm(summed))

        # Stack the neighbors on a new axis 1 and fold back to [ns, bs, d].
        stacked_inputs = torch.stack(per_neighbor_inputs, dim=1)
        stacked_outputs = torch.stack(per_neighbor_outputs, dim=1)

        return (stacked_inputs.reshape(ns, bs, d),
                stacked_outputs.reshape(ns, bs, d))

    def retro_decoder_cross_attention(self,
                                      retriever_input,
                                      retriever_output,
                                      retriever_attn_mask,
                                      layernorm_input,
                                      layernorm_output,
                                      inference_params,
                                      bias_dropout_add_func):
        """Cross attention for Retro decoder.

        For `retro_decoder_with_retriever` layers the retriever is run first
        to produce `retriever_output`; for other layer types the
        `retriever_output` passed in is used as-is. The decoder activations
        are then chunked, attended against `retriever_output`, un-chunked,
        and added to the residual.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            m  : Number of tokens per chunk.
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).

        Args:
            retriever_input: Passed to `self.retriever` as `hidden_states`
                (only read for `retro_decoder_with_retriever` layers).
            retriever_output: Keys/values source for the cross attention;
                recomputed here for `retro_decoder_with_retriever` layers.
            retriever_attn_mask: Passed to `self.retriever` as both
                `attention_mask` and `retriever_attn_mask`.
            layernorm_input: Pre-layernorm activations (residual candidate).
            layernorm_output: Post-layernorm activations, unpacked as
                [ns, bs, d].
            inference_params: Forwarded to `self.retriever`.
            bias_dropout_add_func: Callable `(x, bias, residual, prob)`.

        Returns:
            Tuple `(retriever_output, layernorm_input, layernorm_output)`.

        Raises:
            Exception: For `retro_decoder_with_retriever` layers when `ns` is
                not a multiple of the chunk length (path marked untested).
        """

        ns, bs, d = layernorm_output.shape
        # Number of chunks, rounding up.
        l = int(np.ceil(ns / self.retro_chunk_length))

        # Retrieve neighbors.
        if self.layer_type == LayerType.retro_decoder_with_retriever:
            first_ns = ns % self.retro_chunk_length
            if first_ns > 0:
                # NOTE: the unconditional raise makes the rest of this branch
                # unreachable; it appears to be kept as a sketch of the
                # intended handling of a partial first chunk.
                raise Exception("test this case.")
                first_chunk, rest_chunk = \
                    layernorm_output[:first_ns], layernorm_output[first_ns:]
                first_chunk = torch.nn.functional.pad(
                    first_chunk,
                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
                    'constant',
                    0)
                chunked_output = \
                    torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
            else:
                chunked_output = layernorm_output # [l * m, bs, d]
            # [l * m, bs, d] -> [l, m, bs, d] -> [m, bs, l, d] -> [m, bs * l, d]
            chunked_output = chunked_output \
                .reshape(l, self.retro_chunk_length, bs, d) \
                .permute(1, 2, 0, 3) \
                .reshape(self.retro_chunk_length, bs * l, d) \
                .contiguous()

            # Get Encoder Output
            retriever_output = self.retriever(
                hidden_states=retriever_input,
                attention_mask=retriever_attn_mask,
                retriever_output=chunked_output,
                retriever_attn_mask=retriever_attn_mask,
                inference_params=inference_params) # [r, k * bs * l , d]
            retriever_output = retriever_output.reshape(
                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]

        # Chunks.
        # Drop the first `pad` positions, then append m - 1 zero rows along
        # the sequence dimension so the result can be reshaped into l chunks
        # of m tokens.
        pad = (ns - 1) % self.retro_chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
            'constant', 0)
        # [l * m, bs, d] -> [l, m, bs, d] -> [m, bs, l, d]
        padded_chunked_output = padded_chunks \
            .reshape(l, self.retro_chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        # [m, bs, l, d] -> [m, bs * l, d]
        padded_chunked_output = padded_chunked_output.reshape(
            self.retro_chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            # The fused op is given an all-zero residual here; the real
            # residual is added below, after the attention result has been
            # un-chunked and shifted back to the original sequence positions.
            layernorm_input = bias_dropout_add_func(
                attention_output,
                None if attention_bias is None else attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)
            layernorm_input = layernorm_input \
                .reshape(self.retro_chunk_length, bs, l, d) \
                .permute(2, 0, 1, 3) # [l, m, bs, d]
            layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
            # Prepend `pad` zero rows (undoing the earlier slice) and truncate
            # to the original sequence length.
            layernorm_input = torch.nn.functional.pad(
                layernorm_input,
                (0, 0, 0, 0, pad, 0),
                'constant', 0)[:ns] # [ns, b, d]
            layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        return retriever_output, layernorm_input, layernorm_output

1059
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        """Run self attention, optional cross attention and the MLP.

        Args:
            hidden_states: Layer input, [s, b, h].
            attention_mask: Mask passed to the self attention.
            encoder_output: Encoder activations for `LayerType.decoder`
                cross attention.
            enc_dec_attn_mask: Mask for `LayerType.decoder` cross attention.
            retriever_input: Forwarded to the retro decoder cross attention.
            retriever_output: Forwarded to the retro encoder/decoder cross
                attention.
            retriever_attn_mask: Forwarded to the retro decoder cross
                attention.
            inference_params: Forwarded to the self attention and the retro
                decoder cross attention.
            rotary_pos_emb: Forwarded to the self attention.

        Returns:
            The layer output, or `(output, retriever_output)` when the layer
            type is `LayerType.retro_decoder_with_retriever`.

        Raises:
            Exception: If `self.layer_type` is not a supported layer type.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        #
        # The function is selected unconditionally (not only when drop_path
        # is disabled) because the cross-attention branches below always
        # need it, regardless of drop_path.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        if self.drop_path is None:
            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            # The attention bias may be None; mirror the guard used for the
            # MLP bias further down.
            if attention_bias is not None:
                attention_output = attention_output + attention_bias
            out = torch.nn.functional.dropout(attention_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Cross attention.
        if self.layer_type == LayerType.encoder:
            pass
        elif self.layer_type == LayerType.decoder:
            layernorm_input, layernorm_output = \
                self.default_decoder_cross_attention(
                    encoder_output,
                    enc_dec_attn_mask,
                    layernorm_input,
                    layernorm_output,
                    bias_dropout_add_func)
        elif self.layer_type == LayerType.retro_encoder:
            layernorm_input, layernorm_output = \
                self.retro_encoder_cross_attention(
                    retriever_output,
                    layernorm_input,
                    layernorm_output,
                    bias_dropout_add_func)
        elif self.layer_type in (LayerType.retro_decoder,
                                 LayerType.retro_decoder_with_retriever):
            retriever_output, layernorm_input, layernorm_output = \
                self.retro_decoder_cross_attention(
                    retriever_input,
                    retriever_output,
                    retriever_attn_mask,
                    layernorm_input,
                    layernorm_output,
                    inference_params,
                    bias_dropout_add_func)
        else:
            raise Exception("Unsupported layer type, '%s'." %
                            self.layer_type.name)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        if self.layer_type == LayerType.retro_decoder_with_retriever:
            return output, retriever_output
        else:
            return output
1189
1190


1191
1192
1193
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a clone of `hidden_states`; all other arguments are ignored."""
        return hidden_states.clone()


liangjing's avatar
v1  
liangjing committed
1217
def _get_num_layers(args, model_type, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: Parsed arguments. Reads ``retro_encoder_layers``,
            ``encoder_num_layers``, ``decoder_num_layers``, ``num_layers``,
            ``standalone_embedding_stage``,
            ``pipeline_model_parallel_split_rank`` and
            ``transformer_pipeline_model_parallel_size``.
        model_type: A ``ModelType`` value.
        is_decoder: Selects the decoder layer count. Only consulted when the
            pipeline-model-parallel world size is not greater than 1.

    Returns:
        The number of transformer layers for this rank. This is 0 on pipeline
        rank 0 when a standalone embedding stage is used.

    Raises:
        AssertionError: If the layer counts are not divisible by the number
            of pipeline ranks they are spread over, or if the split rank is
            missing for an encoder-and-decoder model.
    """
    # The Retro encoder has its own, fixed layer count.
    if model_type == ModelType.retro_encoder:
        return args.retro_encoder_layers

    # Without pipeline parallelism every layer lives on this rank.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return args.decoder_num_layers if is_decoder else args.encoder_num_layers

    if model_type == ModelType.encoder_and_decoder:
        assert args.pipeline_model_parallel_split_rank is not None

        # When a standalone embedding stage is used, a rank is taken from
        # the encoder's ranks, to be used for the encoder's embedding
        # layer. This way, the rank referenced by the 'split rank' remains
        # the same whether or not a standalone embedding stage is used.
        num_ranks_in_encoder = (
            args.pipeline_model_parallel_split_rank - 1
            if args.standalone_embedding_stage else
            args.pipeline_model_parallel_split_rank
        )
        num_ranks_in_decoder = \
            args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
        assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
            'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
        assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
            'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)

        if not mpu.is_pipeline_stage_before_split():
            return args.decoder_num_layers // num_ranks_in_decoder
        if args.standalone_embedding_stage \
                and mpu.get_pipeline_model_parallel_rank() == 0:
            return 0
        return args.encoder_num_layers // num_ranks_in_encoder

    assert args.num_layers == args.encoder_num_layers
    assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by transformer_pipeline_model_parallel_size'

    # When a standalone embedding stage is used, all transformer layers
    # are divided among pipeline rank >= 1, while on pipeline rank 0,
    # ranks either contain the input embedding layer (virtual pp rank 0),
    # or no layers at all (virtual pp rank >= 1).
    if args.standalone_embedding_stage \
            and mpu.get_pipeline_model_parallel_rank() == 0:
        return 0
    return args.num_layers // args.transformer_pipeline_model_parallel_size


liangjing's avatar
v1  
liangjing committed
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
    """Choose the ``LayerType`` for the layer numbered ``layer_number``.

    ``default_layer_type`` is returned unless the retriever is enabled
    (``args.retro_add_retriever``) and ``layer_number`` is listed in
    ``retro_layer_numbers``. For such layers, a Retro decoder model gets
    ``retro_decoder_with_retriever`` on the first listed layer and
    ``retro_decoder`` on the others, and a Retro encoder model gets
    ``retro_encoder``. Any other model type raises ``Exception``.
    """
    args = get_args()
    is_retro_layer = (args.retro_add_retriever
                      and layer_number in retro_layer_numbers)
    if not is_retro_layer:
        return default_layer_type

    if model_type == ModelType.retro_decoder:
        if layer_number == retro_layer_numbers[0]:
            return LayerType.retro_decoder_with_retriever
        return LayerType.retro_decoder

    if model_type == ModelType.retro_encoder:
        return LayerType.retro_encoder

    raise Exception("Unsupported model type, '%s'." % model_type)


1288
1289
1290
class ParallelTransformer(MegatronModule):
    """Transformer class."""

liangjing's avatar
v1  
liangjing committed
1291
1292
    def __init__(self, config,
                 model_type, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True,
                 post_process=True,
                 drop_path_rate=0.0):
        """Build the stack of transformer layers owned by this rank.

        Args:
            config: Transformer configuration object; hyperparameters such as
                ``hidden_size``, ``num_layers``, the recompute settings and
                ``sequence_parallel`` are read from it.
            model_type: ``ModelType`` of the enclosing model.
            layer_type: Default ``LayerType`` for the layers built here.
            self_attn_mask_type: ``AttnMaskType`` forwarded to every layer.
            post_layer_norm: Build a final layer norm (only when
                ``post_process`` is also true).
            pre_process: When false, ``forward()`` uses the tensor given to
                ``set_input_tensor()`` instead of its own argument.
            post_process: Whether this rank is allowed to apply the final
                layer norm.
            drop_path_rate: Upper end of the per-layer drop-path schedule,
                which is linearly spaced from 0 over ``config.num_layers``.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = model_type
        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl
        self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpointing flags.
        self.recompute_granularity = config.recompute_granularity
        self.recompute_method = config.recompute_method
        self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
        self.transformer_engine_v_0_10 = False
        self.transformer_engine_v_0_11 = False
        self.transformer_engine_v_0_8 = False
        if self.transformer_impl == 'transformer_engine':
            # Imported lazily and published as a module global so the other
            # methods of this class can reference it.
            global transformer_engine
            import transformer_engine
            from importlib.metadata import version
            from pkg_resources import packaging

            te_version = packaging.version.Version(version("transformer-engine"))
            if te_version >= packaging.version.Version("0.8.0"):
                self.transformer_engine_v_0_8 = True
            if te_version >= packaging.version.Version("0.10.0"):
                self.transformer_engine_v_0_10 = True
            if te_version >= packaging.version.Version("0.11.0"):
                self.transformer_engine_v_0_11 = True

            del version, packaging

            assert not args.squared_relu, "TransformerEngine does not support squared relu activation."

        self.use_fp8 = args.fp8 is not None
        self.fp8_recipe = None
        self.fp8_group = None
        if self.use_fp8:
            assert args.transformer_impl == 'transformer_engine', \
                'transformer-engine required for fp8 training and inference'
            self.fp8_group = mpu.get_amax_reduction_group()
            if args.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        # Bookkeeping used by forward() to detect the first microbatch.
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(args, model_type,
                                          layer_type == LayerType.decoder)

        # Per-layer drop-path rate, indexed by (global layer number - 1).
        self.drop_path_rates = torch.linspace(
            0, self.drop_path_rate, config.num_layers).tolist()

        self.retro_layer_numbers = None
        if model_type == ModelType.retro_decoder:
            retro_layer_start = 6 if config.num_layers <= 15 else 9
            self.retro_layer_numbers = \
                np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
        if model_type == ModelType.retro_encoder:
            self.retro_layer_numbers = [1]

        # Transformer layers.
        if args.retro_add_retriever:
            assert self.recompute_granularity != 'full', \
                "Full recompute not supported for Retro."
            assert args.transformer_impl == 'local', \
                "Transformer engine does not support Retro layers."

        def build_layer(layer_number):
            """Build the layer with 1-based global index ``layer_number``."""
            if args.transformer_impl == 'local':
                current_layer_type = _get_layer_type(
                    model_type, layer_type, self.retro_layer_numbers,
                    layer_number)
                return ParallelTransformerLayer(
                    config,
                    layer_number,
                    layer_type=current_layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                # Optional keyword arguments are only passed to TE versions
                # new enough to accept them: 'bias' (>= 0.8), 'activation'
                # (>= 0.10) and 'normalization' (>= 0.11).
                extra_transformer_engine_kwargs = {}
                if self.transformer_engine_v_0_8:
                    extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
                if self.transformer_engine_v_0_10:
                    extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
                if self.transformer_engine_v_0_11:
                    extra_transformer_engine_kwargs["normalization"] = args.normalization
                return transformer_engine.pytorch.TransformerLayer(
                    config.hidden_size,
                    config.ffn_hidden_size,
                    config.num_attention_heads,
                    layernorm_epsilon=config.layernorm_epsilon,
                    hidden_dropout=config.hidden_dropout,
                    attention_dropout=config.attention_dropout,
                    init_method=config.init_method,
                    output_layer_init_method=config.output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=config.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group(),
                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                    fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                    apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
                    attention_softmax_in_fp32=config.attention_softmax_in_fp32,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=config.sequence_parallel,
                    params_dtype=config.params_dtype,
                    apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True,
                    **extra_transformer_engine_kwargs)

        if config.virtual_pipeline_model_parallel_size is not None:
            assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

            # Update dropout rate for Retro encoder.
            if model_type == ModelType.retro_encoder:
                for layer in self.layers:
                    if layer.self_attention.use_flash_attn:
                        # Store the probability itself, mirroring the
                        # non-flash branch below; an nn.Dropout module is not
                        # a valid dropout probability.
                        # NOTE(review): assumes FlashSelfAttention passes
                        # dropout_p straight to the flash-attention kernel;
                        # confirm against its definition in this file.
                        layer.self_attention.core_attention_flash.dropout_p = \
                            args.retro_encoder_attention_dropout
                    else:
                        layer.self_attention.core_attention.attention_dropout.p = \
                            args.retro_encoder_attention_dropout
                    layer.hidden_dropout = args.retro_encoder_hidden_dropout

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                config.hidden_size,
                eps=config.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=config.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)
1507

Mohammad's avatar
Mohammad committed
1508
    def _get_layer(self, layer_number):
        """Return the entry of ``self.layers`` at index ``layer_number``."""
        return self.layers[layer_number]
Mohammad's avatar
Mohammad committed
1510

1511
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing.

        Runs every layer in ``self.layers`` over ``hidden_states``, wrapping
        layers in an activation checkpoint according to
        ``self.recompute_method``:

        * ``'uniform'``: layers are processed in consecutive chunks of
          ``self.recompute_num_layers`` and every chunk is checkpointed.
        * ``'block'``: the first ``self.recompute_num_layers`` layers are
          checkpointed one by one; the remaining layers run normally.

        Returns:
            The hidden states produced by the last layer.

        Raises:
            ValueError: If ``self.recompute_method`` is neither of the above.
        """
        def custom(start, end):
            """Return a callable that runs layers ``start`` .. ``end - 1``."""
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
            return custom_forward

        use_te = self.transformer_impl == 'transformer_engine'

        te_forward_kwargs = {}
        if use_te:
            te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
            if self.transformer_engine_v_0_10:
                te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

        def run_layers(x, start, end, checkpoint):
            """Run layers ``start`` .. ``end - 1`` on ``x``, optionally
            under an activation checkpoint."""
            chunk_forward = custom(start, end)
            if use_te:
                if checkpoint:
                    return transformer_engine.pytorch.checkpoint(
                        chunk_forward,
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
                        x, attention_mask, encoder_output,
                        enc_dec_attn_mask, **te_forward_kwargs)
                return chunk_forward(
                    x, attention_mask, encoder_output,
                    enc_dec_attn_mask, **te_forward_kwargs)

            # The four None values fill the positional layer-forward
            # parameters that sit between enc_dec_attn_mask and
            # rotary_pos_emb (presumably the retriever inputs and
            # inference_params; verify against the layer's forward()).
            layer_args = (x, attention_mask,
                          encoder_output, enc_dec_attn_mask,
                          None, None, None, None, rotary_pos_emb)
            if checkpoint:
                return tensor_parallel.checkpoint(
                    chunk_forward,
                    self.distribute_saved_activations,
                    *layer_args)
            return chunk_forward(*layer_args)

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            start = 0
            while start < self.num_layers:
                # Clamp the final chunk so that a recompute_num_layers that
                # does not divide num_layers cannot index past the last layer.
                end = min(start + self.recompute_num_layers, self.num_layers)
                hidden_states = run_layers(hidden_states, start, end, True)
                start += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for index in range(self.num_layers):
                hidden_states = run_layers(
                    hidden_states, index, index + 1,
                    index < self.recompute_num_layers)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

1590
    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func.

        The stored tensor is only read by forward() when
        ``self.pre_process`` is false.
        """
        self.input_tensor = input_tensor

1600
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        """Run the layers owned by this rank over ``hidden_states``.

        When ``self.pre_process`` is false the ``hidden_states`` argument is
        ignored and the tensor stored by ``set_input_tensor()`` is used
        instead. With ``recompute_granularity == 'full'`` the layers run
        through ``_checkpointed_forward``; otherwise they are called one
        after another. The final layer norm is applied when both
        ``self.post_process`` and ``self.post_layer_norm`` are true.

        The retriever arguments are forwarded only to 'local' layers;
        Transformer Engine layers receive ``is_first_microbatch`` and
        ``checkpoint_core_attention`` instead.

        Returns:
            The resulting hidden states.
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # RNG context.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # Forward layers.
        with rng_context:
            # The fp8_autocast context manager is a no-op when enabled=False.
            # The if...else serves to short circuit name resolution for
            # fp8_autocast, since transformer_engine is only imported when
            # the Transformer Engine implementation is selected.
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is first microbatch
                num_microbatches = get_num_microbatches()
                if self.num_microbatches_in_previous_step != num_microbatches:
                    self.microbatch_count = 0 # Reset count on new batch size rampup interval
                self.num_microbatches_in_previous_step = num_microbatches
                is_first_microbatch = self.microbatch_count % num_microbatches == 0

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                    }

                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
                        if self.transformer_engine_v_0_10:
                            forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                    else:
                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                        forward_kwargs['retriever_input'] = retriever_input
                        forward_kwargs['retriever_output'] = retriever_output
                        forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                        # First Retro decoder layer returns both hidden_states
                        # and retriever_output. Make retriever_output available
                        # to subsequent Retro layers.
                        if isinstance(hidden_states, tuple):
                            assert len(hidden_states) == 2
                            hidden_states, retriever_output = hidden_states
                            forward_kwargs["retriever_output"] = retriever_output

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states