# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Transformer."""
import math
from contextlib import nullcontext
import torch
import torch.nn.functional as F

from megatron import get_timers, get_args, get_global_memory_buffer
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample 
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias

class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]

        # TODO (rprenger): this could be made easier to read.
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently.
        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
        max_ind = max_ind.view(-1) # [s*b]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger): this does each expert in serial, but it could be parallelized.

        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices,:]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices,:] = output
            output_bias_total[local_indices,:] = output_bias

        output_total = output_total*max_prob
        output_bias_total = output_bias_total*max_prob
        output_total = output_total.view(s, b, h)
        output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total


class CoreAttention(MegatronModule):
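    """Core scaled-dot-product attention.

    Takes query/key/value tensors of shape [sq/sk, b, np, hn], computes
    softmax(Q K^T / norm_factor) under the given mask, applies attention
    dropout, and returns the context tensor of shape [sq, b, hp].
    """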

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
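        # Numerical sketch (added comment, assumptions noted): with hn = 64
        # the baseline scaling is 1/sqrt(64) = 1/8; at layer 12 with
        # query-key layer scaling enabled the score matmul uses
        # alpha = 1/(8 * 12), and the matching coeff = 12 passed to
        # FusedScaleMaskSoftmax is understood to scale the scores back up
        # inside the fp32 softmax, so only the low-precision matmul sees the
        # smaller values.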

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.

        if not self.sequence_parallel:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        hidden_states = mpu.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())
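    # Cache-shape sketch (illustrative, not part of the original file): for
    # incremental inference each layer pre-allocates key/value buffers of
    # shape [max_sequence_len, max_batch_size, np, hn] in params_dtype, and
    # each forward call writes its new keys/values into the slice
    # [sequence_len_offset : sequence_len_offset + sq,
    #  batch_size_offset : batch_size_offset + b].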

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer, key_layer, value_layer, attention_mask)
        else:
            context_layer = self.core_attention(
                query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
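
# Usage sketch (illustrative, not part of the original file): inside a layer,
#     fn = bias_dropout_add_fused_train if training else bias_dropout_add_fused_inference
#     out = fn(attn_output, attn_bias.expand_as(residual), residual, hidden_dropout)
# computes residual + dropout(attn_output + attn_bias); the @torch.jit.script
# variants are intended to let the JIT fuse these elementwise operations.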


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = mpu.make_viewless_tensor(inp = output,
                                              requires_grad = output.requires_grad,
                                              keep_graph = True)

        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        return hidden_states.clone()


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flag.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
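            # Worked example (illustrative, consistent with the comment
            # above): with args.num_layers = 8, 2 pipeline stages and a
            # virtual size of 4, each chunk holds self.num_layers = 1 layer
            # and offset = virtual_rank * (8 // 4) + pipeline_rank * 1, so
            # pipeline rank 0 builds layers 0, 2, 4, 6 across its four chunks
            # and pipeline rank 1 builds layers 1, 3, 5, 7.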
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward
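        # Illustrative example (not from the original file): with
        # self.num_layers = 8 and recompute_num_layers = 2, 'uniform' below
        # checkpoints the inputs of layer chunks [0-1], [2-3], [4-5], [6-7],
        # while 'block' checkpoints layers 0 and 1 individually and runs
        # layers 2-7 without recomputation.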

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This further reduces memory usage by keeping fewer activation checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of device memory while avoiding redundant
            # recomputation for the remaining layers.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        if self.sequence_parallel:
            rng_context = mpu.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states