# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Transformer."""
import math
import os
from contextlib import nullcontext
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from megatron import core
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.jit import jit_fuser
from megatron.core.models.common.embeddings.rotary_pos_embedding import (
    RotaryEmbedding,
    apply_rotary_pos_emb,
)
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.parallel_state import (
    get_tensor_and_expert_parallel_group,
    get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import (
    gather_from_sequence_parallel_region_to_moe,
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name,
    reduce_scatter_to_sequence_parallel_region_from_moe,
)
from megatron.legacy.model.enums import AttnMaskType, AttnType, LayerType
from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl
from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.legacy.model.utils import (
    attention_mask_func,
    erf_gelu,
    get_norm,
    openai_gelu,
)
from megatron.training import get_args, get_timers

from .module import MegatronModule

try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        from flash_attn.flash_attn_interface import (
            flash_attn_varlen_func as flash_attn_unpadded_func,
        )
    except ImportError:
        flash_attn_unpadded_func = None

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

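# A concrete instance of the notation above (an illustrative example only,
# not a recommended configuration): with h=4096, n=32 and p=8 model-parallel
# partitions, each partition holds np = n/p = 4 heads and hp = h/p = 512
# hidden units, and each head has hn = h/n = 128 channels.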
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h]
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output
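    # Illustrative sketch of DropPath in use (not part of the original file):
    # with drop_prob=0.2 in training mode, roughly 20% of the samples in a
    # batch have this residual-branch output zeroed, and the survivors are
    # rescaled by 1/keep_prob so the expected value is unchanged, e.g.
    #
    #   drop_path = DropPath(drop_prob=0.2).train()
    #   out = drop_path(hidden_state)   # hidden_state: [s, b, h]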

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, is_expert=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            ffn_hidden_size,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            is_expert=is_expert,
        )

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            def swiglu(x):
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            skip_bias_add=True,
            input_is_parallel=True,
            is_expert=is_expert,
        )

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
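    # Illustrative shape flow for forward() (a sketch, assuming tensor-parallel
    # size p, ffn_hidden_size = 4h and no gated linear unit): hidden_states
    # [s, b, h] is projected to [s, b, 4h/p] by dense_h_to_4h, run through the
    # activation, and projected back to [s, b, h] by dense_4h_to_h; with
    # skip_bias_add=True the output bias is returned separately rather than
    # added here.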

def sinkhorn(cost, tol=0.0001):
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps)
        d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps)
        error = torch.mean(torch.abs(d1_old-d1))
        d1_old = d1
    return d1*cost*d0.unsqueeze(1)
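# Illustrative use of sinkhorn() (a sketch mirroring how SwitchMLP.forward
# calls it below): given per-token router logits of shape [tokens, experts],
# the iteration above rescales exp(cost) so that rows and columns become
# approximately balanced, which discourages all tokens from collapsing onto a
# single expert, e.g.
#
#   logits = router(hidden_states).view(-1, num_experts)    # [tokens, experts]
#   balanced = sinkhorn(logits.detach().to(torch.float32))  # same shape
#   _, expert_ids = torch.max(balanced, dim=1)               # expert per token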


def get_router_linear_layer(config):
    args = get_args()
    router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False)
    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
        config.init_method(router.weight)
    setattr(router.weight, 'sequence_parallel', config.sequence_parallel)
    return router


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, config):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = get_router_linear_layer(config)
        self.expert_parallel_size = mpu.get_expert_model_parallel_world_size()
        self.sequence_parallel = config.sequence_parallel
        self.add_bias = config.add_bias_linear

        assert args.num_experts % self.expert_parallel_size == 0
        self.num_local_experts = args.num_experts // self.expert_parallel_size
        local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts
        self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]

        self.local_experts = torch.nn.ModuleList()
        for i in range(self.num_local_experts):
            self.local_experts.append(ParallelMLP(config, is_expert=True))

    def gather_indices(self, local_indices):
        """Gather tensors and concatenate along the first dimension."""
        group = get_tensor_and_expert_parallel_group()
        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return local_indices

        dim_size = list(local_indices.size())
        dim_size[0] = dim_size[0] * world_size

        # TODO pre allocate memory
        output = torch.empty(dim_size, dtype=local_indices.dtype,
                             device=torch.cuda.current_device())
        torch.distributed._all_gather_base(
            output, local_indices.contiguous(), group=group
        )
        return output

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        args = get_args()
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states).view(-1, args.num_experts)

        # TODO (rprenger) Right now we're just using the sinkhorn algorithm
        # for load balancing. There should be an option to do no load balancing
        # and the algorithm and parameters should be further tested
        if self.training:
            with torch.no_grad():
                sinkroute = sinkhorn(route.detach().to(dtype=torch.float32))
                _, max_ind = torch.max(sinkroute, dim=1)
            route = torch.sigmoid(route)
            max_prob = route[torch.arange(route.size(0)), max_ind]
        else:
            route = torch.sigmoid(route)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        hidden_states = hidden_states.view(-1, hidden_states.size(2))
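        # Shape summary at this point (an explanatory note, not from the
        # original code): route is [s*b, num_experts] after the sigmoid,
        # max_ind holds the selected expert per token with shape [s*b],
        # max_prob is [s*b, 1], and hidden_states is flattened to [s*b, h].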

        # TODO (rprenger) This could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        if self.sequence_parallel or (self.expert_parallel_size > 1):
            global_hidden_states = \
                gather_from_sequence_parallel_region_to_moe(hidden_states)
            global_indices = self.gather_indices(max_ind)
        else:
            global_hidden_states = hidden_states
            global_indices = max_ind

        output_total = torch.zeros_like(global_hidden_states)
        if self.add_bias:
            output_bias_total = torch.zeros_like(global_hidden_states)

        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if self.add_bias:
                output_bias = output_bias.expand_as(output)
                output_bias_total[local_indices, :] = output_bias

        if self.sequence_parallel or (self.expert_parallel_size > 1):
            output_total = \
                reduce_scatter_to_sequence_parallel_region_from_moe(output_total)
            if self.add_bias:
                output_bias_total = \
                    reduce_scatter_to_sequence_parallel_region_from_moe(output_bias_total)

                # bias is duplicated across tensor parallelism ranks;
                # reduce scatter reduces bias across tensor parallel ranks
                output_bias_total = \
                    output_bias_total/mpu.get_tensor_model_parallel_world_size()

        output_total = output_total*max_prob
        output_total = output_total.view(s, b, h)
        if self.add_bias:
            output_bias_total = output_bias_total*max_prob
            output_bias_total = output_bias_total.view(s, b, h)
        else:
            output_bias_total = None

        return output_total, output_bias_total


class CoreAttention(MegatronModule):

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2],
                                          output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
        """

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
        assert all((i.is_cuda for i in (q,k,v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                        device=q.device)
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )

        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output
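    # Illustrative call sketch for FlashSelfAttention (an example that assumes
    # flash-attn and einops are installed; not part of the original file):
    #
    #   attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
    #   q = k = v = torch.randn(b, s, heads, head_dim,
    #                           dtype=torch.float16, device='cuda')
    #   out = attn(q, k, v)   # [b, s, heads, head_dim]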


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = config.params_dtype
        self.sequence_parallel = config.sequence_parallel
        self.config = config
        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups

        query_projection_size = config.kv_channels * config.num_attention_heads
        if self.group_query_attention:
            kv_projection_size = args.kv_channels * args.num_query_groups
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads

        self.use_flash_attn = args.use_flash_attn \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            query_projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        if self.group_query_attention:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError('Currently the num_query_groups should be '
                                          'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = core.utils.divide(
                        args.num_query_groups, world_size)
        else:
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size + 2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=args.add_bias_linear or args.add_qkv_bias,
                gather_output=False)
        else:
            assert attention_type == AttnType.cross_attn

            if self.group_query_attention:
                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
            assert query_projection_size == kv_projection_size

            self.query = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

        self.core_attention = CoreAttention(self.layer_number, config,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            query_projection_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:

            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
            key_layer,
            value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
            query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
            value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]


            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim = 2
            )
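        # Illustrative note (not from the original code): with, e.g., 8
        # attention heads and 2 query groups per partition, each key/value
        # head above is repeated 4 times along dim 2 so that key_layer and
        # value_layer carry np heads, matching query_layer's [s, b, np, hn]
        # layout.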

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add
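# Illustrative equivalence (a sketch, not part of the original file): the
# fused helpers below behave like the eager-mode expression
#
#   out = residual + F.dropout(x + bias if bias is not None else x,
#                              p=prob, training=training)
#
# with @jit_fuser asking the JIT to fuse the bias add, dropout and residual
# add into a single kernel where supported.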


@jit_fuser
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@jit_fuser
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: Optional[torch.Tensor],
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_norm \
            = config.apply_residual_connection_post_layernorm

        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection

        # Normalize the input data.
        self.input_norm = get_norm(config)

        # Self attention.
        self.self_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = config.hidden_dropout
        self.bias_dropout_fusion = config.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Normalize the attention output
        self.post_attention_norm = get_norm(config)

        # Cross attention.
        if self.layer_type in (LayerType.decoder,
                               LayerType.retro_decoder,
                               LayerType.retro_decoder_with_retriever,
                               LayerType.retro_encoder):
            self.inter_attention = ParallelAttention(
                config,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Normalize the attention output.
            self.post_inter_attention_norm = get_norm(config)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(config)
        else:
            self.mlp = ParallelMLP(config)

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

        if args.retro_add_retriever:
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = args.retro_chunk_length
            self.retro_retrieved_length = \
                args.retro_num_retrieved_chunks * args.retro_chunk_length

        # Retriever (bi-directional transformer with cross attention)
        if layer_type == LayerType.retro_decoder_with_retriever:
            self.retriever = ParallelTransformer(
                config=config,
                model_type=ModelType.retro_encoder,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=True,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

    def default_decoder_cross_attention(self,
                                        encoder_output,
                                        enc_dec_attn_mask,
                                        norm_input,
                                        norm_output,
                                        bias_dropout_add_func):
        '''Cross attention for a standard encoder-decoder model.'''

        # Attention.
        attention_output, attention_bias = \
            self.inter_attention(norm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        if attention_bias is not None:
            attention_bias = attention_bias.expand_as(residual)

        # Bias-dropout-add.
        with self.bias_dropout_add_exec_handler():
            norm_input = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout)

        # Normalize.
        norm_output = self.post_inter_attention_norm(norm_input)

        return norm_input, norm_output

    def retro_encoder_cross_attention(self,
                                      retriever_output,
                                      norm_input,
                                      norm_output,
                                      bias_dropout_add_func):
        """Cross attention for Retro encoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = norm_output.shape # [r, bs * l * k, d]

        # Divide sequence dimension into chunks.
        chunked_outputs = norm_output.reshape(self.retro_retrieved_length,
                                              -1,
                                              self.retro_num_neighbors,
                                              d)
        chunked_outputs_before_norm = \
            norm_input.reshape(self.retro_retrieved_length, -1,
                               self.retro_num_neighbors, d) # [r, bs*l, k, d]

        # Per-chunk attention.
        norm_inputs = []
        norm_outputs = []
        for k in range(self.retro_num_neighbors):

            # Attention.
            chunked_output = chunked_outputs[:,:,k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output, # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output) # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_norm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_norm[:,:,k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                norm_input = bias_dropout_add_func(
                    attention_output,
                    None if attention_bias is None else attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                norm_inputs.append(norm_input)

            # Layer norm.
            norm_output = self.post_inter_attention_norm(norm_input)
            norm_outputs.append(norm_output)

        # Concatenate layer norms.
        # norm_input : [r, k * bs * l, d]
        # norm_output : [r, k * bs * l, d]
        norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d)
        norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d)

        return norm_input, norm_output

    def retro_decoder_cross_attention(self,
                                      retriever_input,
                                      retriever_output,
                                      retriever_attn_mask,
                                      norm_input,
                                      norm_output,
                                      inference_params,
                                      bias_dropout_add_func):
        """Cross attention for Retro decoder.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            m  : Number of tokens per chunk.
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).
        """

        ns, bs, d = norm_output.shape
        l = int(np.ceil(ns / self.retro_chunk_length))

        # Retrieve neighbors.
        if self.layer_type == LayerType.retro_decoder_with_retriever:
            first_ns = ns % self.retro_chunk_length
            if first_ns > 0:
                first_chunk, rest_chunk = \
                    norm_output[:first_ns], norm_output[first_ns:]
                first_chunk = torch.nn.functional.pad(
                    first_chunk,
                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
                    'constant',
                    0)
                chunked_output = \
                    torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
            else:
                chunked_output = norm_output # [l * m, bs, d]
            chunked_output = chunked_output \
                .reshape(l, self.retro_chunk_length, bs, d) \
                .permute(1, 2, 0, 3) \
                .reshape(self.retro_chunk_length, bs * l, d) \
                .contiguous()
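            # [m, bs * l, d]: each chunk now acts as an independent batch entry
            # for the retriever (encoder) call below.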

            # Get Encoder Output
            retriever_output = self.retriever(
                hidden_states=retriever_input,
                attention_mask=retriever_attn_mask,
                retriever_output=chunked_output,
                retriever_attn_mask=retriever_attn_mask,
                inference_params=inference_params) # [r, k * bs * l , d]
            retriever_output = retriever_output.reshape(
                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]

        # Chunks.
        pad = (ns - 1) % self.retro_chunk_length
        attending_chunks = norm_output[pad:]
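        # The leading `pad` tokens fall before the end of the first chunk, so
        # they are excluded from the retrieval cross attention; zeros are padded
        # back in for them after the attention below.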
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, self.retro_chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            self.retro_chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            norm_input = bias_dropout_add_func(
                attention_output,
                None if attention_bias is None else attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)
            norm_input = norm_input \
                .reshape(self.retro_chunk_length, bs, l, d) \
                .permute(2, 0, 1, 3) # [l, m, bs, d]
            norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d)
            norm_input = torch.nn.functional.pad(
                norm_input,
                (0, 0, 0, 0, pad, 0),
                'constant', 0)[:ns] # [ns, b, d]
            # TODO: better redesign with inference param
            args = get_args()
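            # retro_attention_gate scales the retrieval cross-attention branch
            # before it is added back to the residual stream (1.0 keeps it at
            # full strength).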
            norm_input = args.retro_attention_gate * norm_input + residual

        # Layer norm post the decoder cross attention.
        norm_output = self.post_inter_attention_norm(norm_input)

        return retriever_output, norm_input, norm_output

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):

        # Update the params in case the retro param changes during inference
        # TODO: better redesign with inference param
        args = get_args()
        if args.retro_add_retriever:
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = args.retro_chunk_length
            self.retro_retrieved_length = \
                args.retro_num_retrieved_chunks * args.retro_chunk_length

        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        norm_output = self.input_norm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                norm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # JIT scripting an nn.Module with dropout does not trigger
            # the fusion kernel. For now, we use two different
            # nn.functional routines to account for varying dropout
            # semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)
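            # In both cases the functional form is
            #   norm_input = residual + dropout(attention_output + bias),
            # with the bias term skipped when attention_bias is None.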

            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                norm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            norm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        norm_output = self.post_attention_norm(norm_input)

        # Cross attention.
        if self.layer_type == LayerType.encoder:
            pass
        elif self.layer_type == LayerType.decoder:
            norm_input, norm_output = \
                self.default_decoder_cross_attention(
                    encoder_output,
                    enc_dec_attn_mask,
                    norm_input,
                    norm_output,
                    bias_dropout_add_func)
        elif self.layer_type == LayerType.retro_encoder:
            norm_input, norm_output = \
                self.retro_encoder_cross_attention(
                    retriever_output,
                    norm_input,
                    norm_output,
                    bias_dropout_add_func)
        elif self.layer_type in (LayerType.retro_decoder,
                                 LayerType.retro_decoder_with_retriever):
            retriever_output, norm_input, norm_output = \
                self.retro_decoder_cross_attention(
                    retriever_input,
                    retriever_output,
                    retriever_attn_mask,
                    norm_input,
                    norm_output,
                    inference_params,
                    bias_dropout_add_func)
        else:
            raise Exception("Unsupported layer type, '%s'." %
                            self.layer_type.name)

        # MLP.
        mlp_output, mlp_bias = self.mlp(norm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_norm:
            residual = norm_output
        else:
            residual = norm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        if self.layer_type == LayerType.retro_decoder_with_retriever:
            return output, retriever_output
        else:
            return output


class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        return hidden_states.clone()


def _get_num_layers(args, model_type, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank."""
    is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
    if model_type == ModelType.retro_encoder:
        num_layers = args.retro_encoder_layers
    elif mpu.get_pipeline_model_parallel_world_size() > 1:
        assert not is_encoder_and_decoder_model, "This is no longer supported."
        assert args.num_layers == args.encoder_num_layers
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # When a standalone embedding stage is used, all transformer layers
        # are divided among pipeline rank >= 1, while on pipeline rank 0,
        # ranks either contain the input embedding layer (virtual pp rank 0),
        # or no layers at all (virtual pp rank >= 1).
        num_layers = (
            0
            if args.standalone_embedding_stage
            and mpu.get_pipeline_model_parallel_rank() == 0 else
            args.num_layers // args.transformer_pipeline_model_parallel_size
        )
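        # Example: 24 layers with transformer_pipeline_model_parallel_size=4
        # places 6 layers on each (non-embedding) pipeline rank.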
    else:
        if not is_decoder:
            num_layers = args.encoder_num_layers
        else:
            num_layers = args.decoder_num_layers
    return num_layers


def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
    args = get_args()
    if args.retro_add_retriever and layer_number in retro_layer_numbers:
        if model_type == ModelType.retro_decoder:
            return LayerType.retro_decoder_with_retriever \
                if layer_number == retro_layer_numbers[0] \
                   else LayerType.retro_decoder
        elif model_type == ModelType.retro_encoder:
            return LayerType.retro_encoder
        else:
            raise Exception("Unsupported model type, '%s'." % model_type)
    else:
        return default_layer_type


class ParallelTransformer(MegatronModule):
    """Transformer stack.

    Builds the transformer layers assigned to this pipeline (and virtual
    pipeline) rank, using either the local or Transformer Engine layer
    implementation, with optional activation recomputation and FP8 support.
    """

    def __init__(self, config,
                 model_type, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_norm=True,
                 pre_process=True,
                 post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = model_type
        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_norm = post_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl
        self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpointing flags.
        self.recompute_granularity = config.recompute_granularity
        self.recompute_method = config.recompute_method
        self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
        self.transformer_engine_v_0_10 = False
        self.transformer_engine_v_0_11 = False
        self.transformer_engine_v_0_8 = False
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine

            if core.utils.is_te_min_version("0.8.0"):
                self.transformer_engine_v_0_8 = True
            if core.utils.is_te_min_version("0.10.0"):
                self.transformer_engine_v_0_10 = True
            if core.utils.is_te_min_version("0.11.0"):
                self.transformer_engine_v_0_11 = True

            assert not args.squared_relu, ("TransformerEngine does not support squared "
                                           "relu activation.")

        self.use_fp8 = args.fp8 is not None
        self.fp8_recipe = None
        self.fp8_group = None
        if self.use_fp8:
            assert args.transformer_impl == 'transformer_engine', \
                'transformer-engine required for fp8 training and inference'
            self.fp8_group = mpu.get_amax_reduction_group(tp_only_amax_red=config.tp_only_amax_red)
            if args.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(args, model_type,
                                          layer_type==LayerType.decoder)

        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, config.num_layers)]
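        # Stochastic depth: the drop-path rate grows linearly with depth, from
        # 0 for the first layer up to drop_path_rate for the last.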

        self.retro_layer_numbers = None
        if model_type == ModelType.retro_decoder:
            retro_layer_start = 6 if config.num_layers <= 15 else 9
            self.retro_layer_numbers = \
                np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
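            # E.g., a 12-layer decoder uses Retro layers [6, 9, 12]; a 24-layer
            # decoder uses [9, 12, 15, 18, 21, 24].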
        if model_type == ModelType.retro_encoder:
            self.retro_layer_numbers = [1]

        # Transformer layers.
        if args.retro_add_retriever:
            assert self.recompute_granularity != 'full', \
                "Full recompute not supported for Retro."
            assert args.transformer_impl == 'local', \
                "Transformer engine does not support Retro layers."
        def build_layer(layer_number):
            if args.transformer_impl == 'local':
                current_layer_type = _get_layer_type(
                    model_type, layer_type, self.retro_layer_numbers,
                    layer_number)
                return ParallelTransformerLayer(
                    config,
                    layer_number,
                    layer_type=current_layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                # Only pass kwargs supported by the installed Transformer Engine version.
                extra_transformer_engine_kwargs = {}
                if self.transformer_engine_v_0_8:
                    extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
                if self.transformer_engine_v_0_10:
                    extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
                if self.transformer_engine_v_0_11:
                    extra_transformer_engine_kwargs["normalization"] = args.normalization
                assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32."
                assert (
                    (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and args.fp16) == config.apply_query_key_layer_scaling
                ), ("Unsupported config for apply_query_key_layer_scaling in TransformerEngine. If --apply-query-key-layer-scaling is "
                    "provided, set env-var NVTE_APPLY_QK_LAYER_SCALING=1 and you must be using fp16.")
                return transformer_engine.pytorch.TransformerLayer(
                    config.hidden_size,
                    config.ffn_hidden_size,
                    config.num_attention_heads,
                    layernorm_epsilon=config.layernorm_epsilon,
                    hidden_dropout=config.hidden_dropout,
                    attention_dropout=config.attention_dropout,
                    init_method=config.init_method,
                    output_layer_init_method=config.output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=config.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group() if mpu.is_initialized() else None,
                    tp_size=mpu.get_tensor_model_parallel_world_size(),
                    get_rng_state_tracker=get_cuda_rng_tracker
                    if get_cuda_rng_tracker().is_initialized()
                    else None,
                    fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=config.sequence_parallel,
                    params_dtype=config.params_dtype,
                    apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True,
                    **extra_transformer_engine_kwargs)

        if config.virtual_pipeline_model_parallel_size is not None:
            assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
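            # E.g., with num_layers=8, pipeline size 2 and virtual size 4, each
            # model chunk holds one layer and virtual rank v on pipeline rank p
            # builds global layer 2 * v + p + 1 (layer numbers are 1-based).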
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])
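            # Layer numbers passed to build_layer are global across the whole
            # model and 1-based, hence the pipeline/virtual-stage offset.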

            # Update dropout rate for Retro encoder.
            if model_type == ModelType.retro_encoder:
                for layer in self.layers:
                    if layer.self_attention.use_flash_attn:
                        # Flash attention expects a float dropout probability,
                        # not an nn.Dropout module.
                        layer.self_attention.core_attention_flash.dropout_p = \
                            args.retro_encoder_attention_dropout
                    else:
                        layer.self_attention.core_attention.attention_dropout.p =\
                            args.retro_encoder_attention_dropout
                    layer.hidden_dropout = args.retro_encoder_hidden_dropout

        if self.post_process and self.post_norm:
            # Final layer norm before output.
            self.final_norm = get_norm(config)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
            return custom_forward

        te_forward_kwargs = {}
        if self.transformer_impl == 'transformer_engine':
            te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
            if self.transformer_engine_v_0_10:
                te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # Checkpointing fewer, larger chunks further reduces memory usage.
            l = 0
            while l < self.num_layers:
                if self.transformer_impl == 'transformer_engine':
                    hidden_states = transformer_engine.pytorch.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask, **te_forward_kwargs)
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask,
                        None, None, None, None, rotary_pos_emb)
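                    # The four None placeholders above correspond to
                    # retriever_input, retriever_output, retriever_attn_mask and
                    # inference_params in ParallelTransformerLayer.forward.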

                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of device memory while avoiding redundant
            # re-computation for the remaining layers.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = transformer_engine.pytorch.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            tensor_parallel.get_cuda_rng_tracker,
                            mpu.get_tensor_model_parallel_group(),
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # RNG context.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # Forward layers.
        with rng_context:
            # fp8_autocast is only entered when FP8 is in use; the if...else also
            # short-circuits name resolution so transformer_engine is never
            # referenced when it was not imported (local implementation).
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is first microbatch
                if self.num_microbatches_in_previous_step != get_num_microbatches():
                    self.microbatch_count = 0 # Reset count on new batch size rampup interval
                self.num_microbatches_in_previous_step = get_num_microbatches()
                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0
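                # Transformer Engine uses this hint to cache FP8-cast weights
                # across the microbatches of a single optimizer step.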

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                    }

                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
                        if self.transformer_engine_v_0_10:
                            forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                    else:
                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                        forward_kwargs['retriever_input'] = retriever_input
                        forward_kwargs['retriever_output'] = retriever_output
                        forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                        # First Retro decoder layer returns both hidden_states
                        # and retriever_output. Make retriever_output available
                        # to subsequent Retro layers.
                        if isinstance(hidden_states, tuple):
                            assert len(hidden_states) == 2
                            hidden_states, retriever_output = hidden_states
                            forward_kwargs["retriever_output"] = retriever_output

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_norm:
            hidden_states = self.final_norm(hidden_states)

        return hidden_states

    def load_state_dict(self, state_dict, strict=True):
        """Customize load."""

        # Handle renaming layernorm -> norm in component names
        state_dict_ = {}
        for key in state_dict.keys():
            # Bypass TransformerEngine module parameters.
            if "layernorm_qkv" in key or "layernorm_mlp" in key:
                state_dict_[key] = state_dict[key]
                continue
            newkey = key.replace("layernorm", "norm")
            state_dict_[newkey] = state_dict[key]

        super().load_state_dict(state_dict_, strict)