# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Transformer."""
import math
import os
from contextlib import nullcontext
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from megatron import core
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings import apply_rotary_pos_emb
from megatron.core.jit import jit_fuser
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.parallel_state import (
    get_expert_tensor_and_model_parallel_group,
    get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import (
    gather_from_sequence_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name,
)
from megatron.legacy.model.enums import AttnMaskType, AttnType, LayerType
from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl
from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.legacy.model.utils import (
    attention_mask_func,
    erf_gelu,
    get_norm,
    openai_gelu,
)
from megatron.training import get_args, get_timers

import torch._dynamo
torch._dynamo.config.suppress_errors = True

from .module import MegatronModule

try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        from flash_attn.flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_unpadded_func,
        )
    except ImportError:
        flash_attn_unpadded_func = None

try:
    from flash_attn.flash_attn_triton import flash_attn_func
except ImportError:
    flash_attn_func = None
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

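# Illustrative sketch (not part of the original file): how the per-partition
# sizes in the notation above relate to one another for a hypothetical
# configuration with h=4096 hidden units, n=32 heads and p=8 partitions.
def _notation_example(h=4096, n=32, p=8):
    heads_per_partition = n // p      # "np" in the notation above
    hidden_per_partition = h // p     # "hp"
    hidden_per_head = h // n          # "hn"
    return heads_per_partition, hidden_per_partition, hidden_per_head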
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h]
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output

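# Illustrative sketch (not part of the original file): DropPath zeroes whole
# samples along the batch dimension and rescales the survivors by 1/keep_prob
# so the expected activation magnitude is preserved during training.
def _drop_path_example():
    layer = DropPath(drop_prob=0.5)
    layer.train()
    hidden = torch.randn(16, 4, 8)   # [s, b, h]
    out = layer(hidden)              # some batch columns zeroed, the rest scaled by 2.0
    return out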
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, is_expert=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            ffn_hidden_size,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            is_expert=is_expert,
        )

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            @torch.compile(mode="max-autotune-no-cudagraphs")
            def swiglu(x):
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            skip_bias_add=True,
            input_is_parallel=True,
            is_expert=is_expert,
        )

    # @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias

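# Illustrative sketch (not part of the original file): the SwiGLU activation used
# above splits the doubled projection into two halves and gates one with the
# other, so an input of width 2*ffn yields an output of width ffn.
def _swiglu_example():
    x = torch.randn(16, 4, 2 * 128)   # [s, b, 2*ffn_hidden_size]
    a, b = torch.chunk(x, 2, dim=-1)
    out = F.silu(a) * b               # [s, b, ffn_hidden_size]
    return out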
def sinkhorn(cost, tol=0.0001):
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps)
        d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps)
        error = torch.mean(torch.abs(d1_old-d1))
        d1_old = d1
    return d1*cost*d0.unsqueeze(1)

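# Illustrative sketch (not part of the original file): sinkhorn() iteratively
# rescales rows and columns of exp(cost) so that the routing matrix becomes
# approximately doubly stochastic before the per-token argmax is taken.
def _sinkhorn_example():
    logits = torch.randn(8, 4)                 # [tokens, experts]
    balanced = sinkhorn(logits.float())
    _, expert_choice = torch.max(balanced, dim=1)
    return expert_choice                       # one expert index per token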

def get_router_linear_layer(config):
    args = get_args()
    router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False)
    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
        config.init_method(router.weight)
    setattr(router.weight, 'sequence_parallel', config.sequence_parallel)
    return router


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, config):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = get_router_linear_layer(config)
        self.expert_parallel_size = mpu.get_expert_model_parallel_world_size()
        self.sequence_parallel = config.sequence_parallel
        self.add_bias = config.add_bias_linear

        assert args.num_experts % self.expert_parallel_size == 0
        self.num_local_experts = args.num_experts // self.expert_parallel_size
        local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts
        self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]

        self.local_experts = torch.nn.ModuleList()
        for i in range(self.num_local_experts):
            self.local_experts.append(ParallelMLP(config, is_expert=True))

        self.tp_ep_group = get_expert_tensor_and_model_parallel_group()

    def gather_indices(self, local_indices):
        """ Gather tensors and concatinate along the first dimension."""
        world_size = torch.distributed.get_world_size(group=self.tp_ep_group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return local_indices

        dim_size = list(local_indices.size())
        dim_size[0] = dim_size[0] * world_size

        # TODO pre allocate memory
        output = torch.empty(dim_size, dtype=local_indices.dtype,
                             device=torch.cuda.current_device())
        torch.distributed._all_gather_base(
            output, local_indices.contiguous(), group=self.tp_ep_group
        )
        return output

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        args = get_args()
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states).view(-1, args.num_experts)

        # TODO (rprenger) Right now we're just using the sinkhorn algorithm
        # for load balancing. There should be an option to do no load balancing
        # and the algorithm and parameters should be further tested
        if self.training:
            with torch.no_grad():
                sinkroute = sinkhorn(route.detach().to(dtype=torch.float32))
                _, max_ind = torch.max(sinkroute, dim=1)
            route = torch.sigmoid(route)
            max_prob = route[torch.arange(route.size(0)), max_ind]
        else:
            route = torch.sigmoid(route)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        hidden_states = hidden_states.view(-1, hidden_states.size(2))

        # TODO (rprenger) this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        if self.sequence_parallel or (self.expert_parallel_size > 1):
            global_hidden_states = \
                gather_from_sequence_parallel_region(hidden_states, group=self.tp_ep_group)
            global_indices = self.gather_indices(max_ind)
        else:
            global_hidden_states = hidden_states
            global_indices = max_ind

        output_total = torch.zeros_like(global_hidden_states)
        if self.add_bias:
            output_bias_total = torch.zeros_like(global_hidden_states)

        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if self.add_bias:
                output_bias = output_bias.expand_as(output)
                output_bias_total[local_indices, :] = output_bias

        if self.sequence_parallel or (self.expert_parallel_size > 1):
            output_total = \
                reduce_scatter_to_sequence_parallel_region(output_total, group=self.tp_ep_group)
            if self.add_bias:
                output_bias_total = \
                    reduce_scatter_to_sequence_parallel_region(output_bias_total, group=self.tp_ep_group)

                # bias is duplicated across tensor parallelism ranks;
                # reduce scatter reduces bias across tensor parallel_ranks
                output_bias_total = \
                    output_bias_total/mpu.get_tensor_model_parallel_world_size()

        output_total = output_total*max_prob
        output_total = output_total.view(s, b, h)
        if self.add_bias:
            output_bias_total = output_bias_total*max_prob
            output_bias_total = output_bias_total.view(s, b, h)
        else:
            output_bias_total = None

        return output_total, output_bias_total

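# Illustrative sketch (not part of the original file): the top-1 routing used in
# SwitchMLP.forward — at evaluation time the gate is a sigmoid over the router
# logits and each token is sent to its highest-scoring expert, keeping that
# probability to rescale the expert output.
def _top1_routing_example():
    route = torch.randn(8, 4)                  # [s*b, num_experts] router logits
    probs = torch.sigmoid(route)
    max_prob, max_ind = torch.max(probs, dim=1)
    return max_prob.unsqueeze(1), max_ind      # gate value and expert id per token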

class CoreAttention(MegatronModule):

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2],
                                          output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer

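# Illustrative sketch (not part of the original file): the shape bookkeeping used
# in CoreAttention.forward — queries/keys are folded to [b*np, s, hn] so one
# batched matmul produces the [b, np, sq, sk] score tensor (the real code uses
# baddbmm with a preallocated buffer and 1/norm_factor scaling).
def _attention_score_shapes_example(sq=16, sk=16, b=2, np_=4, hn=32):
    query = torch.randn(sq, b * np_, hn)
    key = torch.randn(sk, b * np_, hn)
    scores = torch.bmm(query.transpose(0, 1),                 # [b*np, sq, hn]
                       key.transpose(0, 1).transpose(1, 2))   # [b*np, hn, sk]
    return scores.view(b, np_, sq, sk)                        # [b, np, sq, sk]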
class FlashSelfAttentionTorch(torch.nn.Module):
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.attention_dropout = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (S, B, H, D)
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        if os.environ.get('USE_BSHD',None):
            q, k, v = [rearrange(x, 's b h d -> b s h d').contiguous()
                       for x in (q, k, v)]
        else:
            q, k, v = [rearrange(x, 's b h d -> b h s d').contiguous()
                       for x in (q, k, v)]
        # Assumes PyTorch's built-in scaled_dot_product_attention; q/k/v were rearranged above.
        output = F.scaled_dot_product_attention(q, k, v, is_causal=self.causal, dropout_p=self.attention_dropout, scale=self.softmax_scale)
        if os.environ.get('USE_BSHD',None):
            output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
        else:
            output = rearrange(output, 'b h s d -> s b (h d)').contiguous()
        return output

class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

        # Default to the unpadded (varlen) FlashAttention kernel imported above.
        self.flash_attn_func = flash_attn_unpadded_func

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
        """

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
        assert all((i.is_cuda for i in (q,k,v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                        device=q.device)
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )

        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output

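# Illustrative sketch (not part of the original file): the cumulative-sequence-length
# tensor expected by the varlen FlashAttention kernel, built exactly as in
# FlashSelfAttention.forward for a batch of equal-length sequences.
def _cu_seqlens_example(batch_size=3, seqlen=5):
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                              dtype=torch.int32)
    return cu_seqlens  # tensor([ 0,  5, 10, 15], dtype=torch.int32)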
class FlashSelfAttentionTriton(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (S, B, H, D)
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        # flash_attn_triton.flash_attn_func expects (batch, seqlen, nheads, headdim) inputs.
        q, k, v = [rearrange(x, 's b h d -> b s h d').contiguous()
                   for x in (q, k, v)]
        output = flash_attn_func(q, k, v, None, self.causal)
        output = rearrange(output, 'b s h d -> s b (h d)').contiguous()
        return output

class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = config.params_dtype
        self.sequence_parallel = config.sequence_parallel
        self.config = config
        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups

        query_projection_size = config.kv_channels * config.num_attention_heads
        if self.group_query_attention:
            kv_projection_size = args.kv_channels * args.num_query_groups
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads

        self.use_flash_attn = (args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch) \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        self.use_flash_attn_triton = args.use_flash_attn_triton
        self.use_flash_attn_torch = args.use_flash_attn_torch

        if self.use_flash_attn:
            if args.use_flash_attn_cutlass:
                if flash_attn_unpadded_func is None:
                    raise ImportError('FlashAttention is not installed, please install with '
                                      'pip install flash-attn')
            if args.use_flash_attn_triton:
                assert flash_attn_func is not None, "Cannot import the Triton version of FlashAttention"

            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            query_projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        if self.group_query_attention:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError('Currently the num_query_groups should be '
                                          'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = core.utils.divide(
                        args.num_query_groups, world_size)
        else:
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size + 2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=args.add_bias_linear or args.add_qkv_bias,
                gather_output=False)
        else:
            assert attention_type == AttnType.cross_attn

            if self.group_query_attention:
                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
            assert query_projection_size == kv_projection_size

            self.query = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

        self.core_attention = CoreAttention(self.layer_number, config,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        if self.use_flash_attn_triton:
            self.core_attention_flash = FlashSelfAttentionTriton(
                causal=True, attention_dropout=args.attention_dropout
            )
        elif self.use_flash_attn_torch:
            self.core_attention_flash = FlashSelfAttentionTorch(causal=True, attention_dropout=config.attention_dropout)
        elif self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            query_projection_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:

            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
            key_layer,
            value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
            value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if not isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]


            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
            else:
                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
            
            if not self.use_flash_attn_triton and not self.use_flash_attn_torch:
                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias

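# Illustrative sketch (not part of the original file): the grouped-query expansion
# used in ParallelAttention.forward — each key/value head is repeated so the number
# of KV heads matches the number of query heads before core attention runs.
def _gqa_expand_example(sk=16, b=2, ng=2, np_=8, hn=32):
    key = torch.randn(sk, b, ng, hn)                  # [sk, b, ng, hn]
    key = key.repeat_interleave(np_ // ng, dim=2)     # [sk, b, np, hn]
    return key.shape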

def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@jit_fuser
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@jit_fuser
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: Optional[torch.Tensor],
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_norm \
            = config.apply_residual_connection_post_layernorm

        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection

        # Normalize the input data.
        self.input_norm = get_norm(config)

        # Self attention.
        self.self_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = config.hidden_dropout
        self.bias_dropout_fusion = config.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Normalize the attention output
        self.post_attention_norm = get_norm(config)

        # Cross attention.
        if self.layer_type in (LayerType.decoder,
                               LayerType.retro_decoder,
                               LayerType.retro_decoder_with_retriever,
                               LayerType.retro_encoder):
            self.inter_attention = ParallelAttention(
                config,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Normalize the attention output.
            self.post_inter_attention_norm = get_norm(config)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(config)
        else:
            self.mlp = ParallelMLP(config)

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

        if args.retro_add_retriever:
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = args.retro_chunk_length
            self.retro_retrieved_length = \
                args.retro_num_retrieved_chunks * args.retro_chunk_length

        # Retriever (bi-directional transformer with cross attention)
        if layer_type == LayerType.retro_decoder_with_retriever:
            self.retriever = ParallelTransformer(
                config=config,
                model_type=ModelType.retro_encoder,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=True,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

    def default_decoder_cross_attention(self,
                                        encoder_output,
                                        enc_dec_attn_mask,
                                        norm_input,
                                        norm_output,
                                        bias_dropout_add_func):
        '''Cross attention for a standard encoder-decoder model.'''

        # Attention.
        attention_output, attention_bias = \
            self.inter_attention(norm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        if attention_bias is not None:
            attention_bias = attention_bias.expand_as(residual)

        # Bias-dropout-add.
        with self.bias_dropout_add_exec_handler():
            norm_input = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout)
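        # Roughly equivalent unfused computation (illustration only):
        #   norm_input = residual + F.dropout(attention_output + attention_bias,
        #                                     p=self.hidden_dropout,
        #                                     training=self.training)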

        # Normalize.
        norm_output = self.post_inter_attention_norm(norm_input)

        return norm_input, norm_output

    def retro_encoder_cross_attention(self,
                                      retriever_output,
                                      norm_input,
                                      norm_output,
                                      bias_dropout_add_func):
        """Cross attention for Retro encoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = norm_output.shape # [r, bs * l * k, d]

        # Divide sequence dimension into chunks.
        chunked_outputs = norm_output.reshape(self.retro_retrieved_length,
                                              -1,
                                              self.retro_num_neighbors,
                                              d)
        chunked_outputs_before_norm = \
            norm_input.reshape(self.retro_retrieved_length, -1,
                               self.retro_num_neighbors, d) # [r, bs*l, k, d]

        # Per-chunk attention.
        norm_inputs = []
        norm_outputs = []
        for k in range(self.retro_num_neighbors):

            # Attention.
            chunked_output = chunked_outputs[:,:,k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output, # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output) # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_norm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_norm[:,:,k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                norm_input = bias_dropout_add_func(
                    attention_output,
                    None if attention_bias is None else attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                norm_inputs.append(norm_input)

            # Layer norm.
            norm_output = self.post_inter_attention_norm(norm_input)
            norm_outputs.append(norm_output)

        # Concatenate layer norms.
        # norm_input : [r, k * bs * l, d]
        # norm_output : [r, k * bs * l, d]
        norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d)
        norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d)

        return norm_input, norm_output

    def retro_decoder_cross_attention(self,
                                      retriever_input,
                                      retriever_output,
                                      retriever_attn_mask,
                                      norm_input,
                                      norm_output,
                                      inference_params,
                                      bias_dropout_add_func):
        """Cross attention for Retro decoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            m  : Number of tokens per chunk.
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = norm_output.shape
        l = int(np.ceil(ns / self.retro_chunk_length))

        # Retrieve neighbors.
        if self.layer_type == LayerType.retro_decoder_with_retriever:
            first_ns = ns % self.retro_chunk_length
            if first_ns > 0:
                first_chunk, rest_chunk = \
                    norm_output[:first_ns], norm_output[first_ns:]
                first_chunk = torch.nn.functional.pad(
                    first_chunk,
                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
                    'constant',
                    0)
                chunked_output = \
                    torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
            else:
                chunked_output = norm_output # [l * m, bs, d]
            chunked_output = chunked_output \
                .reshape(l, self.retro_chunk_length, bs, d) \
                .permute(1, 2, 0, 3) \
                .reshape(self.retro_chunk_length, bs * l, d) \
                .contiguous()

            # Get Encoder Output
            retriever_output = self.retriever(
                hidden_states=retriever_input,
                attention_mask=retriever_attn_mask,
                retriever_output=chunked_output,
                retriever_attn_mask=retriever_attn_mask,
                inference_params=inference_params) # [r, k * bs * l , d]
            retriever_output = retriever_output.reshape(
                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]

        # Chunks.
        pad = (ns - 1) % self.retro_chunk_length
        attending_chunks = norm_output[pad:]
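        # Illustration: with ns == 2048 and retro_chunk_length == 64,
        # pad == 63, so the first 63 tokens are excluded from cross attention
        # and pass through via the residual / re-padding below.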
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, self.retro_chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            self.retro_chunk_length, bs * l, d).contiguous()

        # Cross attention over the retriever (encoder) output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            norm_input = bias_dropout_add_func(
                attention_output,
                None if attention_bias is None else attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)
            norm_input = norm_input \
                .reshape(self.retro_chunk_length, bs, l, d) \
                .permute(2, 0, 1, 3) # [l, m, bs, d]
            norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d)
            norm_input = torch.nn.functional.pad(
                norm_input,
                (0, 0, 0, 0, pad, 0),
                'constant', 0)[:ns] # [ns, b, d]
            # TODO: better redesign with inference param
            args = get_args()
            norm_input = args.retro_attention_gate * norm_input + residual

        # Layer norm post the decoder attention
        norm_output = self.post_inter_attention_norm(norm_input)

        return retriever_output, norm_input, norm_output

    # @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):

        # Update the params in case the retro param changes during inference
        # TODO: better redesign with inference param
        args = get_args()
        if args.retro_add_retriever:
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = args.retro_chunk_length
            self.retro_retrieved_length = \
                args.retro_num_retrieved_chunks * args.retro_chunk_length

        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        norm_output = self.input_norm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                norm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                norm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            norm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        norm_output = self.post_attention_norm(norm_input)

        # Cross attention.
        if self.layer_type == LayerType.encoder:
            pass
        elif self.layer_type == LayerType.decoder:
            norm_input, norm_output = \
                self.default_decoder_cross_attention(
                    encoder_output,
                    enc_dec_attn_mask,
                    norm_input,
                    norm_output,
                    bias_dropout_add_func)
        elif self.layer_type == LayerType.retro_encoder:
            norm_input, norm_output = \
                self.retro_encoder_cross_attention(
                    retriever_output,
                    norm_input,
                    norm_output,
                    bias_dropout_add_func)
        elif self.layer_type in (LayerType.retro_decoder,
                                 LayerType.retro_decoder_with_retriever):
            retriever_output, norm_input, norm_output = \
                self.retro_decoder_cross_attention(
                    retriever_input,
                    retriever_output,
                    retriever_attn_mask,
                    norm_input,
                    norm_output,
                    inference_params,
                    bias_dropout_add_func)
        else:
            raise Exception("Unsupported layer type, '%s'." %
                            self.layer_type.name)

        # MLP.
        mlp_output, mlp_bias = self.mlp(norm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        if self.layer_type == LayerType.retro_decoder_with_retriever:
            return output, retriever_output
        else:
            return output


class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        return hidden_states.clone()


def _get_num_layers(args, model_type, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank."""
    is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
    if model_type == ModelType.retro_encoder:
        num_layers = args.retro_encoder_layers
    elif mpu.get_pipeline_model_parallel_world_size() > 1:
        assert not is_encoder_and_decoder_model, "This is no longer supported."
        assert args.num_layers == args.encoder_num_layers
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # When a standalone embedding stage is used, all transformer layers
        # are divided among pipeline rank >= 1, while on pipeline rank 0,
        # ranks either contain the input embedding layer (virtual pp rank 0),
        # or no layers at all (virtual pp rank >= 1).
        num_layers = (
            0
            if args.standalone_embedding_stage
            and mpu.get_pipeline_model_parallel_rank() == 0 else
            args.num_layers // args.transformer_pipeline_model_parallel_size
        )
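        # Example: args.num_layers == 24 with
        # transformer_pipeline_model_parallel_size == 4 (and no standalone
        # embedding stage) gives 24 // 4 == 6 layers on each pipeline rank.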
    else:
        if not is_decoder:
            num_layers = args.encoder_num_layers
        else:
            num_layers = args.decoder_num_layers
    return num_layers


def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
    args = get_args()
    if args.retro_add_retriever and layer_number in retro_layer_numbers:
        if model_type == ModelType.retro_decoder:
            return LayerType.retro_decoder_with_retriever \
                if layer_number == retro_layer_numbers[0] \
                   else LayerType.retro_decoder
        elif model_type == ModelType.retro_encoder:
            return LayerType.retro_encoder
        else:
            raise Exception("Unsupported model type, '%s'." % model_type)
    else:
        return default_layer_type


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, config,
                 model_type, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_norm=True,
                 pre_process=True,
                 post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = model_type
        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_norm = post_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl
        self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpointing flags.
        self.recompute_granularity = config.recompute_granularity
        self.recompute_method = config.recompute_method
        self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
        self.transformer_engine_v_0_10 = False
        self.transformer_engine_v_0_11 = False
        self.transformer_engine_v_0_8 = False
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine

            if core.utils.is_te_min_version("0.8.0"):
                self.transformer_engine_v_0_8 = True
            if core.utils.is_te_min_version("0.10.0"):
                self.transformer_engine_v_0_10 = True
            if core.utils.is_te_min_version("0.11.0"):
                self.transformer_engine_v_0_11 = True

            assert not args.squared_relu, ("TransformerEngine does not support squared "
                                           "relu activation.")

        self.use_fp8 = args.fp8 is not None
        self.fp8_recipe = None
        self.fp8_group = None
        if self.use_fp8:
            assert args.transformer_impl == 'transformer_engine', \
                'transformer-engine required for fp8 training and inference'
            self.fp8_group = mpu.get_amax_reduction_group(tp_only_amax_red=config.tp_only_amax_red)
            if args.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(args, model_type,
                                          layer_type==LayerType.decoder)

        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, config.num_layers)]
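        # Stochastic-depth rates grow linearly with depth; e.g. a
        # drop_path_rate of 0.1 over 24 layers yields per-layer rates
        # from 0.0 up to 0.1.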

        self.retro_layer_numbers = None
        if model_type == ModelType.retro_decoder:
            retro_layer_start = 6 if config.num_layers <= 15 else 9
            self.retro_layer_numbers = \
                np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
        if model_type == ModelType.retro_encoder:
            self.retro_layer_numbers = [1]
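        # Example: with config.num_layers == 24, the decoder cross-attends to
        # the retriever at layers [9, 12, 15, 18, 21, 24]; the first of these
        # (layer 9) is the retro_decoder_with_retriever layer.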

        # Transformer layers.
        if args.retro_add_retriever:
            assert self.recompute_granularity != 'full', \
                "Full recompute not supported for Retro."
            assert args.transformer_impl == 'local', \
                "Transformer engine does not support Retro layers."
        def build_layer(layer_number):
            if args.transformer_impl == 'local':
                current_layer_type = _get_layer_type(
                    model_type, layer_type, self.retro_layer_numbers,
                    layer_number)
                return ParallelTransformerLayer(
                    config,
                    layer_number,
                    layer_type=current_layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                # This argument is only available from TE v0.10 onwards.
                extra_transformer_engine_kwargs = {}
                if self.transformer_engine_v_0_8:
                    extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
                if self.transformer_engine_v_0_10:
                    extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
                if self.transformer_engine_v_0_11:
                    extra_transformer_engine_kwargs["normalization"] = args.normalization
                assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32."
                assert (
                    (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and args.fp16) == config.apply_query_key_layer_scaling
                ), ("Unsupported config for apply_query_key_layer_scaling in TransformerEngine. If --apply-query-key-layer-scaling is "
                    "provided, set env-var NVTE_APPLY_QK_LAYER_SCALING=1 and you must be using fp16.")
                return transformer_engine.pytorch.TransformerLayer(
                    config.hidden_size,
                    config.ffn_hidden_size,
                    config.num_attention_heads,
                    layernorm_epsilon=config.layernorm_epsilon,
                    hidden_dropout=config.hidden_dropout,
                    attention_dropout=config.attention_dropout,
                    init_method=config.init_method,
                    output_layer_init_method=config.output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=config.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group() if mpu.is_initialized() else None,
                    tp_size=mpu.get_tensor_model_parallel_world_size(),
                    get_rng_state_tracker=get_cuda_rng_tracker
                    if get_cuda_rng_tracker().is_initialized()
                    else None,
                    fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=config.sequence_parallel,
                    params_dtype=config.params_dtype,
                    apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True,
                    **extra_transformer_engine_kwargs)

        if config.virtual_pipeline_model_parallel_size is not None:
            assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
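            # Example: 8 layers, 2 pipeline stages, virtual size 2 gives
            # 8 // 2 // 2 == 2 layers per model chunk, and virtual rank v on
            # pipeline rank p starts at offset v * 4 + p * 2, matching the
            # interleaved assignment sketched above.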
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

            # Update dropout rate for Retro encoder.
            if model_type == ModelType.retro_encoder:
                for layer in self.layers:
                    if layer.self_attention.use_flash_attn:
                        layer.self_attention.core_attention_flash.dropout_p = \
                            args.retro_encoder_attention_dropout
                    else:
                        layer.self_attention.core_attention.attention_dropout.p =\
                            args.retro_encoder_attention_dropout
                    layer.hidden_dropout = args.retro_encoder_hidden_dropout

        if self.post_process and self.post_norm:
            # Final layer norm before output.
            self.final_norm = get_norm(config)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
            return custom_forward

        te_forward_kwargs = {}
        if self.transformer_impl == 'transformer_engine':
            te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
            if self.transformer_engine_v_0_10:
                te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # Checkpointing whole chunks rather than individual layers further
            # reduces memory usage at the cost of extra recomputation.
            l = 0
            while l < self.num_layers:
                if self.transformer_impl == 'transformer_engine':
                    hidden_states = transformer_engine.pytorch.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask, **te_forward_kwargs)
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask,
                        None, None, None, None, rotary_pos_emb)

                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of device memory by avoiding redundant recomputation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = transformer_engine.pytorch.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            tensor_parallel.get_cuda_rng_tracker,
                            mpu.get_tensor_model_parallel_group(),
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # RNG context.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # Forward layers.
        with rng_context:
            # The fp8_autocast context manager is a no-op when enabled=False.
            # The if...else serves to short circuit name resolution for fp8_autocast
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is first microbatch
                if self.num_microbatches_in_previous_step != get_num_microbatches():
                    self.microbatch_count = 0 # Reset count on new batch size rampup interval
                self.num_microbatches_in_previous_step = get_num_microbatches()
                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
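                # is_first_microbatch is forwarded to TransformerEngine layers,
                # which can use it (e.g., to refresh cached weight copies only
                # once per global batch); see the TE documentation for details.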

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                    }

                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
                        if self.transformer_engine_v_0_10:
                            forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                    else:
                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                        forward_kwargs['retriever_input'] = retriever_input
                        forward_kwargs['retriever_output'] = retriever_output
                        forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                        # First Retro decoder layer returns both hidden_states
                        # and retriever_output. Make retriever_output available
                        # to subsequent Retro layers.
                        if isinstance(hidden_states, tuple):
                            assert len(hidden_states) == 2
                            hidden_states, retriever_output = hidden_states
                            forward_kwargs["retriever_output"] = retriever_output

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_norm:
            hidden_states = self.final_norm(hidden_states)

        return hidden_states

    def load_state_dict(self, state_dict, strict=True):
        """Customize load."""

        # Handle renaming layernorm -> norm in component names
        state_dict_ = {}
        for key in state_dict.keys():
            # Bypass TransformerEngine module parameters.
            if "layernorm_qkv" in key or "layernorm_mlp" in key:
                state_dict_[key] = state_dict[key]
                continue
            newkey = key.replace("layernorm", "norm")
            state_dict_[newkey] = state_dict[key]
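            # e.g., 'layers.0.input_layernorm.weight' -> 'layers.0.input_norm.weight'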

        super().load_state_dict(state_dict_, strict)