# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
18
from contextlib import nullcontext
19
import torch
20
import torch.nn.functional as F
21

22
from megatron import get_timers, get_args
23
from megatron import mpu
24
from .module import MegatronModule
25
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
26
from megatron.model import LayerNorm
27
28
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
29
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
30

31

32
33
34
35
36
37
38
39
40
41
""" We use the following notation throughout this file:
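     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p (attention heads per partition)
     hn: size of one attention head (projection size / n, i.e. kv_channels)
     hp: np * hn (per-partition projection size; h/p when kv_channels * n == h)
     b: batch size
     s: sequence length
     sq, sk: query and key sequence lengths (they may differ in cross-attention)
     l: number of layers

    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. Hyperparameters are read from the global
    arguments returned by get_args().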
"""

47
48
49
50
51
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Activations in this file use the [s, b, h] layout (sequence first), so
    the per-sample dimension is dim 1. The drop mask has shape [1, b, 1, ...]:
    every token of a given sample is either kept (and rescaled by
    1 / keep_prob) or zeroed together.

    Args:
        drop_prob: probability of dropping the residual branch for a sample.
            0 makes the module an identity.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply stochastic depth to ``hidden_state`` ([s, b, ...]).

        Returns ``hidden_state`` unchanged when ``drop_prob`` is 0 or the
        module is in eval mode.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample. In the [s, b, h] layout the batch is
        # dim 1; drawing over dim 0 (the ConvNet convention) would instead
        # drop whole sequence positions across the batch. Inputs with fewer
        # than two dims fall back to dim 0.
        ndim = hidden_state.ndim
        batch_dim = 1 if ndim >= 2 else 0
        shape = tuple(hidden_state.shape[dim] if dim == batch_dim else 1
                      for dim in range(ndim))
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
        output = hidden_state.div(keep_prob) * random_tensor
        return output


69
70
71
72
73
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    Expands the h-dimensional input to ``args.ffn_hidden_size`` with a
    column-parallel linear layer, applies a GeLU-family nonlinearity, and
    projects back to h with a row-parallel linear layer. Both linear layers
    skip their bias add; the second layer's bias is returned to the caller
    instead of being added here.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # h -> ffn_hidden_size; the output stays partitioned across ranks.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation choice: openai_gelu wins over the ONNX-safe erf variant,
        # otherwise F.gelu. When bias_gelu_fusion is on, forward() calls the
        # fused bias+gelu kernel instead of this function.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            chosen_activation = openai_gelu
        elif args.onnx_safe:
            chosen_activation = erf_gelu
        else:
            chosen_activation = F.gelu
        self.activation_func = chosen_activation

        # ffn_hidden_size -> h; the input arrives already partitioned.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Return ``(output, output_bias)``; the output bias is not added."""
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            activated = self.activation_func(projected + projected_bias)
        else:
            activated = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias
119

class SwitchMLP(MegatronModule):
    """Top-1 mixture-of-experts MLP.

    A (non-parallel) linear router scores every token against
    ``args.num_experts`` experts. Each token goes to its highest-scoring
    expert (a ParallelMLP), and both the expert output and the expert bias
    are scaled by that expert's softmax probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        # Router first, experts second: keep this order so parameter creation
        # order is unchanged.
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            [ParallelMLP(init_method, output_layer_init_method)
             for _ in range(args.num_experts)])

    def forward(self, hidden_states):
        """Route each token of ``hidden_states`` ([s, b, h]) to one expert.

        Returns ``(output, output_bias)``, both [s, b, h] and already
        multiplied by the router probability of the chosen expert.
        """
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)

        # Router probabilities over experts, then the top-1 pick per token.
        probs = self.router(hidden_states).softmax(dim=2)
        top_prob, top_expert = probs.max(dim=2)  # each [s, b]

        # Flatten tokens [s, b, h] -> [s*b, h]; each row is routed on its own.
        tokens = hidden_states.view(-1, h)
        top_prob = top_prob.view(-1, 1)  # [s*b, 1]
        top_expert = top_expert.view(-1)  # [s*b]

        output_total = torch.empty_like(tokens)
        output_bias_total = torch.empty_like(tokens)
        # TODO (rprenger): experts run serially; this could be parallelized.
        # Every row is written by exactly one expert, so nothing stays empty.
        for expert_id, expert in enumerate(self.experts):
            rows = torch.nonzero(top_expert == expert_id)  # [k, 1]
            expert_out, expert_bias = expert(tokens[rows, :])
            output_total[rows, :] = expert_out
            output_bias_total[rows, :] = expert_bias.expand_as(expert_out)

        output_total = (output_total * top_prob).view(s, b, h)
        output_bias_total = (output_bias_total * top_prob).view(s, b, h)
        return output_total, output_bias_total
167

168
169

class CoreAttention(MegatronModule):
    """Scaled dot-product attention over already-projected heads.

    Takes per-partition query/key/value tensors shaped [s, b, np, hn] and
    returns the context layer [sq, b, hp]. Masking and softmax are delegated
    to FusedScaleMaskSoftmax.
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        # Query-key layer scaling forces the softmax to run in fp32.
        self.attention_softmax_in_fp32 = (
            True if self.apply_query_key_layer_scaling
            else args.attention_softmax_in_fp32)
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # With layer scaling the raw scores are additionally divided by
        # layer_number, and that same coefficient is handed to the softmax
        # module (presumably to scale back before the softmax).
        coeff = self.layer_number if self.apply_query_key_layer_scaling else None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if coeff is not None:
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. For a single iteration this generates different outputs
        # for different numbers of parallel partitions, but on average it
        # should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute softmax(Q K^T / norm_factor) V for every head.

        query_layer: [sq, b, np, hn]; key_layer, value_layer: [sk, b, np, hn].
        Returns the context layer [sq, b, hp].
        """
        sq = query_layer.size(0)
        b = query_layer.size(1)
        np_ = query_layer.size(2)
        sk = key_layer.size(0)
        bnp = b * np_

        # Fold batch and heads together: [s, b, np, hn] -> [s, b * np, hn].
        q = query_layer.view(sq, bnp, -1)
        k = key_layer.view(sk, bnp, -1)

        # baddbmm requires an input tensor; with beta=0.0 its contents are
        # ignored, so an uninitialized buffer suffices.
        scores_buffer = torch.empty(
            bnp, sq, sk,
            dtype=q.dtype,
            device=torch.cuda.current_device())

        # Raw scores [b * np, sq, sk] = (1 / norm_factor) * Q @ K^T.
        raw_scores = torch.baddbmm(
            scores_buffer,
            q.transpose(0, 1),                  # [b * np, sq, hn]
            k.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # [b, np, sq, sk]: mask + softmax.
        attention_scores = raw_scores.view(b, np_, sq, sk)
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # Dropping attention probabilities drops entire tokens to attend to,
        # as in the original Transformer paper. Without sequence parallelism
        # the dropout runs inside a fork of mpu's CUDA RNG tracker.
        if self.sequence_parallel:
            attention_probs = self.attention_dropout(attention_probs)
        else:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

        # Context: [b * np, sq, sk] @ [b * np, sk, hn] -> [b * np, sq, hn].
        vb = value_layer.size(1)
        vnp = value_layer.size(2)
        hn = value_layer.size(3)
        v = value_layer.view(value_layer.size(0), vb * vnp, -1)
        probs_flat = attention_probs.view(vb * vnp, sq, -1)
        context = torch.bmm(probs_flat, v.transpose(0, 1))

        # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
        context = context.view(vb, vnp, sq, hn).permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] -> [sq, b, hp]
        merged_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
        return context.view(*merged_shape)


309
class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention (self- or cross-attention).

    Projects the input into per-partition query/key/value heads, runs
    CoreAttention (optionally under activation checkpointing), and projects
    the context back to h with a row-parallel linear layer. Input and output
    are [s, b, h]; the output bias is returned separately, not added.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        def column_linear(output_size):
            # Column-parallel h -> output_size projection, output kept split.
            return mpu.ColumnParallelLinear(
                args.hidden_size,
                output_size,
                gather_output=False,
                init_method=init_method)

        # Input projections: one fused QKV for self-attention; separate query
        # (from hidden_states) and key/value (from encoder_output) otherwise.
        if attention_type == AttnType.self_attn:
            self.query_key_value = column_linear(3 * projection_size)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = column_linear(projection_size)
            self.key_value = column_linear(2 * projection_size)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        # Selective recompute: only the core attention is checkpointed.
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output projection back to h.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Run core attention under ``mpu.checkpoint`` (activation recompute)."""
        def custom_forward(*inputs):
            q, k, v, mask = inputs[:4]
            return self.core_attention(q, k, v, mask)

        return mpu.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_len, batch, np, hn] KV-cache buffer."""
        cache_shape = (inference_max_sequence_len,
                       batch_size,
                       self.num_attention_heads_per_partition,
                       self.hidden_size_per_attention_head)
        return torch.empty(*cache_shape,
                           dtype=self.params_dtype,
                           device=torch.cuda.current_device())

    def _heads_view(self, tensor, num_chunks):
        """Reshape [s, b, np * k * hn] to [s, b, np, k * hn] (k = num_chunks)."""
        heads_shape = tensor.size()[:-1] + (
            self.num_attention_heads_per_partition,
            num_chunks * self.hidden_size_per_attention_head)
        return tensor.view(*heads_shape)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Attend and project back to h.

        hidden_states: [sq, b, h]. ``encoder_output`` supplies keys/values for
        cross-attention. ``inference_params``, when given, holds this layer's
        key/value cache in ``key_value_memory_dict``. Returns ``(output, bias)``
        from the output projection; the bias is not added.
        """
        # Fetch (or lazily allocate) this layer's key/value cache.
        if inference_params:
            kv_cache = inference_params.key_value_memory_dict
            if self.layer_number in kv_cache:
                inference_key_memory, inference_value_memory = \
                    kv_cache[self.layer_number]
            else:
                max_len = inference_params.max_sequence_len
                max_batch = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(max_len, max_batch)
                inference_value_memory = self._allocate_memory(max_len, max_batch)
                kv_cache[self.layer_number] = (inference_key_memory,
                                               inference_value_memory)

        # Query, key and value heads, each [s, b, np, hn].
        if self.attention_type == AttnType.self_attn:
            # [sq, b, h] -> [sq, b, np * 3 * hn] -> [sq, b, np, 3 * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)
            mixed_x_layer = self._heads_view(mixed_x_layer, 3)
            query_layer, key_layer, value_layer = \
                mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # [sk, b, h] -> [sk, b, np * 2 * hn] -> [sk, b, np, 2 * hn]
            mixed_kv_layer, _ = self.key_value(encoder_output)
            mixed_kv_layer = self._heads_view(mixed_kv_layer, 2)
            key_layer, value_layer = \
                mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
            # [sq, b, h] -> [sq, b, hp] -> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)
            query_layer = self._heads_view(query_layer, 1)

        # Write the new keys/values into the cache and attend over the whole
        # cached prefix.
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            batch_slice = slice(batch_start, batch_end)
            new_tokens = slice(sequence_start, sequence_end)
            prefix = slice(None, sequence_end)
            inference_key_memory[new_tokens, batch_slice, ...] = key_layer
            inference_value_memory[new_tokens, batch_slice, ...] = value_layer
            key_layer = inference_key_memory[prefix, batch_slice, ...]
            value_layer = inference_value_memory[prefix, batch_slice, ...]

        # Core attention, optionally recomputed in backward.
        attend = (self._checkpointed_attention_forward
                  if self.checkpoint_core_attention
                  else self.core_attention)
        context_layer = attend(query_layer, key_layer, value_layer,
                               attention_mask)

        # Output. [sq, b, h]
        output, bias = self.dense(context_layer)
        return output, bias


498
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias, p=prob)``.

    Dropout is only active when ``training`` is True. Must stay
    TorchScript-compatible: the ``bias_dropout_add_fused_*`` functions in this
    file are ``@torch.jit.script``-ed and call it.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return ``f(x, bias, residual, prob)`` calling bias_dropout_add with
    ``training`` bound; used when bias-dropout fusion is disabled."""
    def _unfused(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _unfused


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript bias+dropout+residual-add with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript bias+dropout+residual-add with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, False)
525
526
527
528
529


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Encoder layers run self-attention then the MLP;
    decoder layers additionally run cross-attention over ``encoder_output``
    between the two. Each sub-block is followed by bias + dropout + residual
    add, with optional stochastic depth (drop_path) on the self-attention and
    MLP branches.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth on the residual branches; None disables it.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP: top-1 mixture of experts when num_experts is set.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler: a no-op
        # context from torch 1.10 on (nvfuser), torch.enable_grad before.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _get_bias_dropout_add_func(self):
        """Pick the bias + dropout + residual-add implementation.

        jit scripting for a nn.module (with dropout) is not triggering the
        fusion kernel, so with fusion enabled two scripted nn.functional
        routines cover the training and inference dropout semantics. Without
        fusion, a plain closure over ``self.training`` is returned.
        """
        if not self.bias_dropout_fusion:
            return get_bias_dropout_add(self.training)
        if self.training:
            return bias_dropout_add_fused_train
        return bias_dropout_add_fused_inference

    def _residual_add(self, output, bias, residual, bias_dropout_add_func):
        """Return ``residual + dropout(output + bias)``.

        Without drop_path, uses ``bias_dropout_add_func`` under the exec
        handler. With drop_path, applies dropout and then drop_path to the
        branch before adding the residual.
        """
        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                return bias_dropout_add_func(
                    output,
                    bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        out = torch.nn.functional.dropout(output + bias,
                                          p=self.hidden_dropout,
                                          training=self.training)
        return residual + self.drop_path(out)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the layer on ``hidden_states`` ([s, b, h]); returns [s, b, h]."""
        # Chosen up front: the decoder cross-attention residual below always
        # uses it, even when drop_path handles the self-attention and MLP
        # residuals. Binding it only on the drop_path-is-None path left it
        # unbound (UnboundLocalError) for decoder layers with drop_path.
        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states
        layernorm_input = self._residual_add(
            attention_output, attention_bias, residual, bias_dropout_add_func)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # The cross-attention residual never applies drop_path.
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self._residual_add(
            mlp_output, mlp_bias, residual, bias_dropout_add_func)
        return output


695
696
697
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        """Create the no-op layer.

        Args:
            layer_number: stored on the instance for parity with the
                ``layer_number`` argument of ParallelTransformerLayer; it is
                not otherwise used by this class.
        """
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a copy of ``hidden_states``.

        All other arguments are accepted only so that this layer can be
        called exactly like a ParallelTransformerLayer; they are ignored.
        The clone guarantees the returned tensor is a distinct object from
        the input.
        """
        return hidden_states.clone()


721
722
723
class ParallelTransformer(MegatronModule):
    """Stack of transformer layers assigned to the current pipeline stage.

    Builds the slice of ParallelTransformerLayer modules owned by this
    pipeline (and, if configured, virtual pipeline) rank. When
    ``post_process`` and ``post_layer_norm`` are both set, it is followed by
    a final LayerNorm. Input and output are [s, b, h] tensors (see the
    notation at the top of this file).
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the layers owned by this rank.

        Args:
            init_method: initializer forwarded to every layer.
            output_layer_init_method: second initializer forwarded to every
                layer (presumably for output projections -- confirm against
                ParallelTransformerLayer).
            layer_type: LayerType of the layers to build.
            self_attn_mask_type: AttnMaskType forwarded to every layer.
            post_layer_norm: if True (together with ``post_process``), build
                and apply a final LayerNorm.
            pre_process: if False, forward() ignores its ``hidden_states``
                argument and uses the tensor set via set_input_tensor().
            post_process: if False, no final LayerNorm is built or applied.
            drop_path_rate: largest stochastic-depth rate. Per-layer rates
                rise linearly from 0 to this value across args.num_layers.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        # Filled in by set_input_tensor(); used when pre_process is False.
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing (recompute) settings.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Distributing saved activations is turned off under sequence
        # parallelism.
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers on this pipeline stage.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # One stochastic-depth rate per global layer; layer_number N (which
        # is 1-based) uses entry N - 1.
        self.drop_path_rates = [
            rate.item()
            for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder stages come after the encoder stages, so their
                    # layer numbering restarts at the split rank.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        """Return the layer at 0-based local index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Supported ``recompute_method`` values:
            'uniform': split the local layers into consecutive chunks of
                ``recompute_num_layers`` layers and checkpoint the input of
                each chunk. The last chunk is shorter when the layer count
                is not a multiple of the chunk size.
            'block': checkpoint the first ``recompute_num_layers`` layers
                individually and run the remaining layers without
                checkpointing.

        Raises:
            ValueError: if ``recompute_method`` is neither of the above.
        """
        def custom(start, end):
            # Build a callable that runs local layers [start, end) in order;
            # this is the unit handed to mpu.checkpoint.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_start = 0
            while chunk_start < self.num_layers:
                # Clamp the final chunk so that a layer count that is not a
                # multiple of recompute_num_layers (e.g. the single no-op
                # layer) does not index past the end of self.layers.
                chunk_end = min(chunk_start + self.recompute_num_layers,
                                self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(chunk_start, chunk_end),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                chunk_start = chunk_end

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for layer_index in range(self.num_layers):
                if layer_index < self.recompute_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(layer_index, layer_index + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(layer_index, layer_index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the hidden states through this stage's layers.

        Args:
            hidden_states: [s, b, h] input. It is ignored when
                ``pre_process`` is False; the tensor given to
                set_input_tensor() is used instead.
            attention_mask: mask forwarded to every layer.
            encoder_output: forwarded to every layer (consumed by decoder
                layers' encoder-decoder attention).
            enc_dec_attn_mask: encoder-decoder attention mask forwarded to
                every layer.
            inference_params: inference state forwarded to every layer. It
                must not be combined with activation recompute.

        Returns:
            The [s, b, h] output. It has passed through the final LayerNorm
            when ``post_process`` and ``post_layer_norm`` are both set.
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # Make sure hidden_states is not a view of another tensor before it
        # enters the layers. make_viewless_tensor() has negligible overhead
        # when its input is already viewless. When hidden_states came from
        # set_input_tensor(), its likely originator (p2p_communication.py)
        # already creates viewless tensors, so the call is mostly there to be
        # future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Under sequence parallelism the layers run inside a fork of the
        # CUDA RNG tracker; otherwise no RNG context is entered.
        if self.sequence_parallel:
            rng_context = mpu.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states