# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
18
from contextlib import nullcontext
19
import torch
20
import torch.nn.functional as F
21

Mohammad's avatar
Mohammad committed
22
from megatron import get_args
23
from megatron import mpu
24
from .module import MegatronModule
25
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
26
from megatron.model import LayerNorm
27
28
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
29
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
30
31
32
33
34
35
36
37
38
39
40

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq, sk: query and key sequence lengths (they differ for cross-attention)
     l: number of layers

    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

46
47
48
49
50
51

class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample.

    Meant for the main path of residual blocks. In training mode, each sample
    along dimension 0 of the input is either zeroed entirely (probability
    ``drop_prob``) or kept and rescaled by ``1 / (1 - drop_prob)``. In eval
    mode, or when ``drop_prob == 0``, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's path.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state

        keep_prob = 1 - self.drop_prob
        # One draw per sample along dim 0, broadcast over every other dim,
        # so this works for tensors of any rank (not just 2D ConvNets).
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0

        if keep_prob > 0.:
            output = hidden_state.div(keep_prob) * random_tensor
        else:
            # drop_prob == 1: the mask is all zeros. Dividing by keep_prob
            # here would produce inf, and inf * 0 = nan. Multiplying by the
            # zero mask instead keeps the autograd graph (zero gradients)
            # without poisoning the output.
            output = hidden_state * random_tensor
        return output


69
70
71
72
73
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward (MLP) block.

    The last dimension of the input (``args.hidden_size``) is projected up to
    ``args.ffn_hidden_size`` by a column-parallel linear layer, passed through
    a GeLU-style nonlinearity, and projected back down to ``args.hidden_size``
    by a row-parallel linear layer.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Up-projection h -> ffn_hidden_size. The output is not gathered
        # across ranks (gather_output=False), and the bias is returned
        # separately (skip_bias_add=True) so forward() can fuse it with the
        # activation.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation. When bias_gelu_fusion is on, forward() calls
        # bias_gelu_impl and never consults activation_func.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Down-projection ffn_hidden_size -> h. It consumes the
        # partition-local activations directly (input_is_parallel=True) and
        # also hands its bias back unapplied.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP.

        Returns:
            ``(output, output_bias)`` from ``dense_4h_to_h``. The bias has not
            been added to ``output``.
        """
        # [s, b, 4hp] activations plus their not-yet-added bias.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            activated = self.activation_func(projected + projected_bias)
        else:
            activated = bias_gelu_impl(projected, projected_bias)

        # Back to [s, b, h].
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias
119

rprenger's avatar
rprenger committed
120
121
122
123
class SwitchMLP(MegatronModule):
    """Top-1 mixture-of-experts MLP.

    A linear router scores every token against ``args.num_experts`` experts
    (each a ``ParallelMLP``). Each token is processed by the single expert
    with the highest softmax score. That expert's output and bias are both
    multiplied by the winning probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        """Route every token to its top-1 expert.

        Args:
            hidden_states: 3-D tensor whose last dim is the hidden size.
                Only the last dim is interpreted. The two leading dims are
                flattened into a token axis and restored afterwards, so the
                [s, b, h] layout used by the transformer layers works as-is.

        Returns:
            ``(output, output_bias)``, each with the same shape as
            ``hidden_states``. As with ``ParallelMLP``, the bias is returned
            rather than added.
        """
        dim0 = hidden_states.size(0)
        dim1 = hidden_states.size(1)
        h = hidden_states.size(2)

        # Router probabilities over experts: [dim0, dim1, num_experts].
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [dim0, dim1, 1]

        # Flatten to one row per token: each token may go to a different
        # expert.
        hidden_states = hidden_states.view(-1, h)  # [dim0*dim1, h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [dim0*dim1, 1]
        max_ind = max_ind.view(-1)  # [dim0*dim1]

        # Every token has exactly one argmax expert, so each row of these
        # buffers is written exactly once by the loop below.
        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            # nonzero() yields indices of shape [k, 1], so the gathered rows
            # are [k, 1, h]. The writes below use the same index, so the
            # shapes line up.
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        # Scale by the router probability of each token's chosen expert.
        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob

        output_total = output_total.view(dim0, dim1, h)
        output_bias_total = output_bias_total.view(dim0, dim1, h)

        return output_total, output_bias_total

168
class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention (self- or cross-attention).

    Takes hidden states of shape [sq, b, h] and returns ``(output, bias)``.
    ``output`` has shape [sq, b, h], and ``bias`` is the output projection's
    bias, which has not been added. For cross-attention
    (``attention_type == AttnType.cross_attn``), keys and values come from
    ``encoder_output`` ([sk, b, h]) instead of ``hidden_states``.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            # With layer scaling, the raw scores are divided by an extra
            # layer_number factor (norm_factor below), and coeff=layer_number
            # is handed to the softmax. The softmax is forced to fp32 in this
            # mode.
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # Fused Q, K, V projection.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Q from the decoder stream; fused K, V from the encoder output.
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized key or value cache buffer.

        The buffer has shape [inference_max_sequence_len, batch_size, np, hn]
        and is allocated on the current CUDA device with ``params_dtype``.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Compute attention.

        Args:
            hidden_states: [sq, b, h] input (source of the queries, and also
                of the keys/values for self-attention).
            attention_mask: mask handed to ``scale_mask_softmax`` along with
                the [b, np, sq, sk] scores.
            encoder_output: [sk, b, h] source of keys/values. Only used (and
                required) for cross-attention.
            inference_params: optional. When given, this layer's keys/values
                are written into, and read back from, a per-layer cache in
                ``inference_params.key_value_memory_dict``.

        Returns:
            ``(output, bias)``: output is [sq, b, h], and bias is the output
            projection's bias, not yet added.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params is not None:
            # NOTE(review): for cross-attention, the full encoder K/V (sk
            # rows) is written at sequence_len_offset on every call, which
            # looks like it only suits offset 0. Confirm before using the
            # cache with cross-attention.
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over everything cached so far, not just this step.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0: the uninitialized contents of matmul_result are ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        # (query_layer is [sq, b * np, hn] here, so size(0) is sq.)
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


442
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout with probability ``prob`` is applied only when ``training`` is
    True. This unfused helper is called both from the TorchScript-compiled
    variants and from the closure built by ``get_bias_dropout_add``.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Bind the ``training`` flag of ``bias_dropout_add``.

    Returns a callable taking ``(x, bias, residual, prob)`` that forwards to
    ``bias_dropout_add`` with the captured ``training`` value.
    """
    def _bound(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bound


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout active."""
    return bias_dropout_add(x, bias, residual, prob, training=True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, training=False)
469
470
471
472
473


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Every layer applies self-attention followed by
    an MLP. Decoder layers (``layer_type == LayerType.decoder``) additionally
    apply cross-attention over ``encoder_output`` between the two. Each
    sub-block's input goes through a LayerNorm, and its output is merged back
    with a bias + dropout + residual add.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        # If set, the residual branch starts from the LayerNorm output rather
        # than the LayerNorm input.
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth on the self-attention and MLP residual branches;
        # None disables it.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # On torch < 1.10 the bias-dropout-add call is wrapped in
        # torch.enable_grad. From 1.10 on (named use_nvfuser here) no wrapper
        # is used.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _get_bias_dropout_add_func(self):
        """Pick the bias + dropout + residual-add implementation for the
        current train/eval mode.

        jit scripting for a nn.module (with dropout) is not triggering the
        fusion kernel. For now, two separately scripted functions cover the
        differing dropout semantics of training and inference.
        """
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Bound before any branching. The decoder cross-attention residual
        # below needs it whether or not drop_path is enabled. Binding it only
        # in the drop_path-is-None branch made decoder layers with
        # drop_path_rate > 0 fail with UnboundLocalError.
        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # NOTE(review): drop_path is not applied to this cross-attention
            # residual, only to self-attention and MLP. Confirm intended.
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


636
637
638
class NoopTransformerLayer(MegatronModule):
    """Placeholder transformer layer that performs no computation.

    Used when a standalone embedding stage is configured
    (args.standalone_embedding_stage == True). In that setup, pipeline rank 0
    gets zero transformer layers, and when virtual pipeline rank >= 1 no
    model parameters are created at all (virtual rank 0 holds the input
    embedding). The model's input and output tensors would then be the same
    tensor. That breaks memory optimizations applied to the output tensor,
    such as deallocating it. Returning a clone decouples the output from the
    input. Ranks holding a no-op layer are generally under-utilized in both
    compute and memory, so the extra copy is not a performance concern.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a copy of ``hidden_states``.

        The remaining arguments exist only to match the
        ``ParallelTransformerLayer.forward`` signature and are ignored.
        """
        copied = hidden_states.clone()
        return copied


662
663
664
class ParallelTransformer(MegatronModule):
    """Transformer class.

    Builds the transformer layers assigned to this pipeline stage (or virtual
    pipeline model chunk) and, when ``post_process`` is set, a final LayerNorm.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Per-layer drop-path rates over the whole model (args.num_layers
        # entries, linearly spaced from 0 to drop_path_rate); indexed below by
        # 1-based layer_number - 1.
        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder ranks follow the encoder ranks, so layer
                    # numbering restarts after the split rank.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Raises:
            ValueError: if ``activations_checkpoint_method`` is neither
                'uniform' nor 'block'.
        """
        def custom(start, end):
            # Build a closure that runs local layers [start, end) in sequence;
            # it is what mpu.checkpoint re-executes during recomputation.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            # The chunk end is clamped to self.num_layers so that a layer count
            # that is not a multiple of activations_checkpoint_num_layers (e.g.
            # the single no-op layer) never indexes past the end of self.layers.
            l = 0
            while l < self.num_layers:
                end = min(l + self.activations_checkpoint_num_layers,
                          self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(l, end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run all local layers and, on the last stage, the final LayerNorm.

        If ``pre_process`` is set, ``hidden_states`` is transposed on dims 0/1
        before the layers; otherwise the tensor stored by set_input_tensor() is
        used and the ``hidden_states`` argument is ignored. ``encoder_output``,
        when given, is transposed the same way. If ``post_process`` is set, the
        result is transposed back and passed through ``final_layernorm``.
        """

        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad = True,
            keep_graph = True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output