# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
18
from contextlib import nullcontext
19
import torch
20
import torch.nn.functional as F
21

22
from megatron import get_timers, get_args
23
from megatron import mpu
24
from .module import MegatronModule
25
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
26
from megatron.model import LayerNorm
27
28
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
29
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
30

31

32
33
34
35
36
37
38
39
40
41
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq / sk: query / key sequence length (they differ in cross attention
              and during KV-cached inference)
     l: number of layers

    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

47
48
49
50
51
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample.

    Meant for the main path of residual blocks. In training mode, each
    sample's whole input is zeroed with probability ``drop_prob``, and the
    surviving samples are rescaled by ``1 / (1 - drop_prob)``. In eval mode,
    or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's path.
        batch_dim: dimension of the input that indexes samples. Activations
            in this file are laid out as [s, b, h], so the default is 1.
            Pass 0 for batch-first inputs.
    """

    def __init__(self, drop_prob=0., batch_dim=1):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.batch_dim = batch_dim

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # Draw one keep/drop decision per sample and broadcast it over every
        # other dim. The mask must vary along the batch dim (dim 1 for
        # [s, b, h]). Varying it along dim 0 would instead drop whole
        # sequence positions shared by every sample in the batch.
        shape = [1] * hidden_state.ndim
        shape[self.batch_dim] = hidden_state.shape[self.batch_dim]
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob > 0.:
            # Rescale survivors so the expectation matches eval mode. With
            # keep_prob == 0 every path is dropped; skipping the division
            # avoids producing 0 * inf = NaN.
            random_tensor.div_(keep_prob)
        return hidden_state * random_tensor


69
70
71
72
73
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    The input of hidden size h is expanded to ``args.ffn_hidden_size`` by a
    column-parallel linear layer, passed through a GeLU-style activation,
    and projected back to h by a row-parallel linear layer. Both linear
    layers are built with ``skip_bias_add=True``, so ``forward`` returns an
    ``(output, output_bias)`` pair and leaves the final bias add to the caller.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # h -> ffn_hidden_size, sharded by columns across the TP group.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size, args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation used when bias+gelu fusion is disabled: the OpenAI
        # approximation wins over the ONNX-safe erf variant, and torch's
        # gelu is the fallback.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # ffn_hidden_size -> h; the input is already sharded by the layer above.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size, args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, 4hp]; the bias comes back un-added.
        intermediate, bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            intermediate = self.activation_func(intermediate + bias)
        else:
            intermediate = bias_gelu_impl(intermediate, bias)

        # [s, b, 4hp] -> [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate)
        return output, output_bias
119

rprenger's avatar
rprenger committed
120
121
122
123
class SwitchMLP(MegatronModule):
    """Top-1 mixture of MLP experts.

    A linear router scores every token against ``args.num_experts`` experts.
    Each token is processed only by its highest-probability expert, and that
    expert's output and bias are both scaled by the router probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            [ParallelMLP(init_method, output_layer_init_method)
             for _ in range(args.num_experts)])

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        seq_len = hidden_states.size(0)
        batch_size = hidden_states.size(1)
        hidden_size = hidden_states.size(2)

        # Router distribution over experts: [s, b, num_experts]
        route_probs = torch.nn.functional.softmax(
            self.router(hidden_states), dim=2)
        # Top-1 expert per token: probability and index, both [s, b]
        top_prob, top_expert = torch.max(route_probs, dim=2)

        # Flatten tokens so each one can be routed independently.
        tokens = hidden_states.view(-1, hidden_size)      # [s*b, h]
        top_prob = top_prob.unsqueeze(2).view(-1, 1)      # [s*b, 1]
        top_expert = top_expert.view(-1)                  # [s*b]

        output_total = torch.empty_like(tokens)
        output_bias_total = torch.empty_like(tokens)
        # TODO (rprenger): experts run one after another; this could be parallelized.
        for expert_id, expert in enumerate(self.experts):
            # [k, 1] row indices of the tokens routed to this expert.
            token_idx = (top_expert == expert_id).nonzero()
            expert_out, expert_bias = expert(tokens[token_idx, :])
            output_total[token_idx, :] = expert_out
            output_bias_total[token_idx, :] = expert_bias.expand_as(expert_out)

        # Weight by router probability and restore the [s, b, h] layout.
        output_total = (output_total * top_prob).view(
            seq_len, batch_size, hidden_size)
        output_bias_total = (output_bias_total * top_prob).view(
            seq_len, batch_size, hidden_size)
        return output_total, output_bias_total
167

168
169

class CoreAttention(MegatronModule):
    """Scaled, masked dot-product attention on one tensor-parallel partition.

    ``forward`` takes query/key/value tensors already split into this rank's
    heads: [sq, b, np, hn] for the query and [sk, b, np, hn] for the key and
    value. It returns the context layer reshaped to [sq, b, hp].
    """

    # Class-level `input` tensor for torch.baddbmm, shared by all instances.
    # baddbmm is called with beta=0.0, so only the buffer's shape and dtype
    # matter, never its contents.
    matmul_input_buffer = None

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax into fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Scores are divided by sqrt(hn). With query-key layer scaling they
        # are also divided by layer_number, and that factor (`coeff`) is
        # forwarded to FusedScaleMaskSoftmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    @staticmethod
    def _get_matmul_input_buffer(shape, dtype):
        """Return the shared baddbmm input buffer for `shape` and `dtype`.

        The buffer is allocated on first use and reused while later calls
        request the same shape and dtype. When either changes, the buffer is
        reallocated instead of failing an assertion. For example, during
        KV-cached inference the key length grows from step to step.
        """
        buffer = CoreAttention.matmul_input_buffer
        if (buffer is None or buffer.size() != torch.Size(shape)
                or buffer.dtype != dtype):
            buffer = torch.empty(*shape, dtype=dtype,
                                 device=torch.cuda.current_device())
            CoreAttention.matmul_input_buffer = buffer
        return buffer

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Reusable baddbmm input tensor: [b * np, sq, sk]
        matmul_input_buffer = CoreAttention._get_matmul_input_buffer(
            (output_size[0] * output_size[1], output_size[2], output_size[3]),
            query_layer.dtype)

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 ignores the buffer's uninitialized contents; alpha applies
        # the 1/norm_factor scaling inside the matmul.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Without sequence parallelism the dropout runs under the tracked
        # model-parallel CUDA RNG state.
        if not self.sequence_parallel:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


315
class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention (self or cross attention).

    Takes hidden states of shape [s, b, h] and returns an ``(output, bias)``
    pair. ``output`` has the input's shape, and ``bias`` is the output
    projection's bias, which is left un-added (``skip_bias_add=True``).
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Head size, and number of heads owned by this tensor-parallel rank.
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, tp_world_size)

        # Input projections. Self attention uses one fused QKV projection.
        # Cross attention projects the query from `hidden_states` and the
        # key/value pair from `encoder_output`.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size, 3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size, projection_size,
                gather_output=False,
                init_method=init_method)
            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size, 2 * projection_size,
                gather_output=False,
                init_method=init_method)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        # 'selective' recompute granularity checkpoints only the core attention.
        self.checkpoint_core_attention = \
            args.recompute_granularity == 'selective'

        # Output projection back to h; its bias is returned, not added.
        self.dense = mpu.RowParallelLinear(
            projection_size, args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Run core attention under activation checkpointing (mpu.checkpoint)."""
        def custom_forward(*inputs):
            # inputs: (query_layer, key_layer, value_layer, attention_mask)
            return self.core_attention(*inputs[:4])

        return mpu.checkpoint(
            custom_forward, False,
            query_layer, key_layer, value_layer, attention_mask)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_seq, batch, np, hn] KV-cache tensor."""
        cache_shape = (inference_max_sequence_len,
                       batch_size,
                       self.num_attention_heads_per_partition,
                       self.hidden_size_per_attention_head)
        return torch.empty(*cache_shape,
                           dtype=self.params_dtype,
                           device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # ---- KV cache: fetch this layer's buffers, allocating on first use.
        if inference_params:
            kv_memory = inference_params.key_value_memory_dict
            if self.layer_number in kv_memory:
                inference_key_memory, inference_value_memory = \
                    kv_memory[self.layer_number]
            else:
                inference_key_memory = self._allocate_memory(
                    inference_params.max_sequence_len,
                    inference_params.max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inference_params.max_sequence_len,
                    inference_params.max_batch_size)
                kv_memory[self.layer_number] = (
                    inference_key_memory, inference_value_memory)

        # ---- Query, key and value projections.
        if self.attention_type == AttnType.self_attn:
            # [sq, b, h] --> [sq, b, np * 3 * hn] --> [sq, b, np, 3 * hn]
            mixed_qkv, _ = self.query_key_value(hidden_states)
            mixed_qkv = mixed_qkv.view(
                *mixed_qkv.size()[:-1],
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head)
            # [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn]
            query_layer, key_layer, value_layer = \
                mpu.split_tensor_along_last_dim(mixed_qkv, 3)
        else:
            # Encoder side: [sk, b, h] --> [sk, b, np, 2 * hn]
            mixed_kv, _ = self.key_value(encoder_output)
            mixed_kv = mixed_kv.view(
                *mixed_kv.size()[:-1],
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head)
            # [sk, b, np, 2 * hn] --> 2 x [sk, b, np, hn]
            key_layer, value_layer = \
                mpu.split_tensor_along_last_dim(mixed_kv, 2)

            # Decoder side: [sq, b, h] --> [sq, b, hp] --> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)
            query_layer = query_layer.view(
                *query_layer.size()[:-1],
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head)

        # ---- KV cache: store the new keys/values, then attend over all cached steps.
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            seq_start = inference_params.sequence_len_offset
            seq_end = seq_start + key_layer.size(0)
            assert seq_end <= inference_key_memory.size(0)

            batch_slice = slice(batch_start, batch_end)
            inference_key_memory[seq_start:seq_end, batch_slice, ...] = key_layer
            inference_value_memory[seq_start:seq_end, batch_slice, ...] = value_layer
            key_layer = inference_key_memory[:seq_end, batch_slice, ...]
            value_layer = inference_value_memory[:seq_end, batch_slice, ...]

        # ---- Core attention, optionally recomputed in backward.
        if self.checkpoint_core_attention:
            run_attention = self._checkpointed_attention_forward
        else:
            run_attention = self.core_attention
        context_layer = run_attention(
            query_layer, key_layer, value_layer, attention_mask)

        # ---- Output projection: [sq, b, hp] --> [sq, b, h]
        output, bias = self.dense(context_layer)
        return output, bias


504
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    The explicit annotations keep this TorchScript-compatible; the
    ``torch.jit.script`` wrappers in this module call it.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return a 4-argument bias-dropout-add with `training` fixed.

    The returned callable takes ``(x, bias, residual, prob)`` and forwards to
    ``bias_dropout_add`` with the captured `training` flag.
    """
    return lambda x, bias, residual, prob: \
        bias_dropout_add(x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled bias-dropout-add with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, training=True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled bias-dropout-add with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, training=False)
531
532
533
534
535


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Structure: input layernorm -> self attention -> residual add ->
    post-attention layernorm -> (decoder only: cross attention -> residual
    add -> layernorm) -> MLP -> residual add. Each residual add applies
    dropout to ``output + bias``. When ``drop_path_rate > 0``, the
    self-attention and MLP residual branches also go through DropPath.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # On torch >= 1.10 (`use_nvfuser`) the bias-dropout-add runs as-is;
        # on older versions it is wrapped in torch.enable_grad().
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _get_bias_dropout_add_func(self):
        """Pick the bias + dropout + residual-add implementation.

        jit scripting for a nn.module (with dropout) is not triggering the
        fusion kernel. For now, two different scripted functions cover the
        training and inference dropout semantics.
        """
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def _bias_dropout_drop_path_add(self, x, bias, residual):
        """Return ``residual + drop_path(dropout(x + bias))`` (drop-path variant)."""
        out = torch.nn.functional.dropout(x + bias,
                                          p=self.hidden_dropout,
                                          training=self.training)
        return residual + self.drop_path(out)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Selected once, up front. The decoder cross-attention residual below
        # has no drop-path variant and always needs this function, even when
        # drop_path is enabled. Binding it only in the `drop_path is None`
        # branch left it undefined (UnboundLocalError) for decoder layers
        # built with drop_path_rate > 0.
        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            layernorm_input = self._bias_dropout_drop_path_add(
                attention_output, attention_bias, residual)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            output = self._bias_dropout_drop_path_add(
                mlp_output, mlp_bias, residual)

        return output


701
702
703
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    This layer exists only for the case where a standalone embedding layer is
    used (i.e., args.standalone_embedding_stage == True). In that case zero
    transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). As a result, the model's
    input and output tensors are the same tensor. That causes an error when
    certain memory optimizations are performed on the output tensor (e.g.,
    deallocating it).

    This layer therefore disconnects the output from the input by returning
    a clone. Ranks holding a no-op layer are generally under-utilized (in
    both compute and memory), so the extra copy is not a performance concern.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        # Only stored; forward() does not use it.
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a copy of `hidden_states`; all other arguments are ignored."""
        return torch.clone(hidden_states)


727
728
729
class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the transformer layers assigned to this pipeline stage (or to one
    virtual-pipeline model chunk of it), optionally followed by a final
    layer norm. Hidden states are [s, b, h] tensors.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the layers owned by this pipeline stage / model chunk.

        Args:
            init_method: forwarded to every ParallelTransformerLayer.
            output_layer_init_method: forwarded to every
                ParallelTransformerLayer.
            layer_type: LayerType of the layers built here. It is also used
                to pick the global layer offset for encoder-and-decoder
                models under pipeline parallelism.
            self_attn_mask_type: AttnMaskType forwarded to every layer.
            post_layer_norm: when True (and post_process is True), a final
                LayerNorm is created and applied in forward().
            pre_process: when False, forward() ignores its hidden_states
                argument and uses the tensor given to set_input_tensor().
            post_process: the final LayerNorm is only created when this is
                True.
            drop_path_rate: largest per-layer drop-path rate. Rates are
                spaced linearly from 0 to this value over args.num_layers.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Saved activations are never distributed under sequence parallelism.
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers on this pipeline stage.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # One drop-path rate per layer, looked up by (global layer_number - 1).
        # NOTE(review): with args.standalone_embedding_stage, pipeline rank 0
        # holds no transformer layers, but the rank-based offsets below are
        # not shifted for that. Confirm that the last stage's layer_number
        # cannot exceed args.num_layers; otherwise this lookup raises
        # IndexError.
        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            vp_size = args.virtual_pipeline_model_parallel_size
            assert args.num_layers % vp_size == 0, \
                'num_layers must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # self.num_layers is still the per-stage count here, and it is
            # what gets split into model chunks below. It must divide evenly,
            # or the floor division silently drops layers.
            assert self.num_layers % vp_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in
            # the stage, divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // vp_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // vp_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder numbering restarts at the split rank.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This makes the model's input and output tensors the same,
            # which breaks certain output tensor optimizations (e.g., pipeline
            # output deallocation). To remedy this, we assign a 'no-op' layer
            # on these ranks, which disconnects the input tensor from the
            # output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index `layer_number`."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        The strategy is chosen by self.recompute_method:
          - 'uniform': checkpoint consecutive chunks of
            self.recompute_num_layers layers. The last chunk is shorter when
            self.num_layers is not a multiple of that count.
          - 'block': checkpoint each of the first self.recompute_num_layers
            layers individually and run the remaining layers without
            checkpointing.

        Raises:
            ValueError: if self.recompute_method is neither 'uniform' nor
                'block'.
        """
        def custom(start, end):
            # Build a callable that runs local layers [start, end) on the
            # positional inputs it receives from mpu.checkpoint.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # This further reduces memory usage by reducing the number of
            # checkpoints.
            l = 0
            while l < self.num_layers:
                # Clamp the chunk end so a trailing partial chunk does not
                # index past the last local layer.
                end = min(l + self.recompute_num_layers, self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(l, end),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes full use of device memory by avoiding redundant
            # re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the local layer stack and, if configured, the final layer norm.

        Args:
            hidden_states: [s, b, h] input. It is ignored when pre_process is
                False; the tensor from set_input_tensor() is used instead.
            attention_mask: forwarded to every layer.
            encoder_output: forwarded to every layer.
            enc_dec_attn_mask: forwarded to every layer.
            inference_params: forwarded to every layer (non-checkpointed
                path only). It must not be combined with activation
                recomputation.

        Returns:
            The [s, b, h] output hidden states.
        """
        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - The input may be a view tensor. For example, with micro batch
        #   size == 1, an upstream transpose followed by '.contiguous()'
        #   yields a view, because '.contiguous()' is then a pass-through.
        #   We don't check for that case explicitly, since
        #   make_viewless_tensor() has negligible overhead when its input is
        #   already viewless.
        # - When the input comes from set_input_tensor() (pre_process is
        #   False), p2p_communication.py likely already produced a viewless
        #   tensor. The call is kept anyway to be future-proof and
        #   corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Under sequence parallelism the layers run inside the context from
        # mpu.get_cuda_rng_tracker().fork(); otherwise inside a no-op
        # context. nullcontext must be instantiated: the class object itself
        # does not implement the context manager protocol.
        if self.sequence_parallel:
            rng_context = mpu.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states