# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. Layer hyperparameters are read from the
    global arguments returned by `get_args()`.
"""

class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    The input (hidden size h) is projected to ``args.ffn_hidden_size`` by a
    column-parallel linear layer, passed through a GeLU-family activation,
    and projected back to h by a row-parallel linear layer. Both linear
    layers skip their bias add: the first bias is applied together with the
    activation here, the second is returned to the caller unapplied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # h -> ffn_hidden_size; output stays partitioned across ranks.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # openai_gelu wins over onnx_safe; otherwise plain F.gelu.
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # ffn_hidden_size -> h; consumes the partitioned activations.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Return ``(output, output_bias)``; the caller adds the bias."""
        # [s, b, h] -> [s, b, 4hp], bias returned separately.
        pre_activation, first_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            activated = self.activation_func(pre_activation + first_bias)
        else:
            activated = bias_gelu_impl(pre_activation, first_bias)

        # [s, b, 4hp] -> [s, b, h]
        projected, projected_bias = self.dense_4h_to_h(activated)
        return projected, projected_bias


class SwitchMLP(MegatronModule):
    """Top-1 mixture-of-experts MLP.

    A linear router scores every token vector against ``args.num_experts``
    experts. Each token is processed only by its highest-probability expert
    (a ``ParallelMLP``), and that expert's output and bias are both scaled by
    the winning router probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            [ParallelMLP(init_method, output_layer_init_method)
             for _ in range(args.num_experts)])

    def forward(self, hidden_states):
        """Route each token to one expert.

        Args:
            hidden_states: tensor of shape [..., h]. In this file's layers
                that is [s, b, h]. All leading dims are flattened into a
                single token axis, so their order does not matter.

        Returns:
            ``(output, output_bias)``, both shaped like ``hidden_states``.
        """
        leading_shape = hidden_states.shape[:-1]
        h = hidden_states.size(-1)

        # Router probabilities over experts: [..., num_experts].
        route = torch.nn.functional.softmax(self.router(hidden_states), dim=-1)
        max_prob, max_ind = torch.max(route, dim=-1)

        # Flatten to one row per token so each can be routed independently.
        # reshape (unlike view) also copes with non-contiguous inputs.
        hidden_states = hidden_states.reshape(-1, h).unsqueeze(1)  # [T, 1, h]
        max_prob = max_prob.reshape(-1, 1, 1)                      # [T, 1, 1]
        max_ind = max_ind.reshape(-1)                              # [T]

        # Every token is assigned to exactly one expert (argmax), so every
        # row below is written by exactly one scatter_.
        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            ind = (max_ind == expert_num).nonzero().unsqueeze(2).repeat(1, 1, h)
            hidden = torch.gather(hidden_states, 0, ind)
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total.scatter_(0, ind, output)
            output_bias_total.scatter_(0, ind, output_bias)

        output_total = (output_total * max_prob).reshape(*leading_shape, h)
        output_bias_total = (output_bias_total * max_prob).reshape(*leading_shape, h)

        return output_total, output_bias_total


class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention layer.

    Supports self-attention (``AttnType.self_attn``) and encoder-decoder
    cross-attention (``AttnType.cross_attn``). It takes input of size
    [s, b, h] and returns ``(output, bias)``. The output has the same size
    as the input; the bias of the output projection is returned unapplied
    for the caller to add.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # With query-key layer scaling, scores are additionally divided by
        # layer_number (see norm_factor) and layer_number is handed to the
        # softmax as `coeff`. The softmax is then forced to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized KV-cache buffer.

        Shape is [max_seq, max_batch, np, hn]. It uses ``params_dtype`` and
        the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def _get_inference_memory(self, inference_params):
        """Return this layer's (key, value) cache buffers.

        The buffers are allocated on first use and stored in
        ``inference_params.key_value_memory_dict`` keyed by ``layer_number``.
        """
        memory_dict = inference_params.key_value_memory_dict
        if self.layer_number not in memory_dict:
            inf_max_seq_len = inference_params.max_sequence_len
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_len, inf_max_batch_size)
            inference_value_memory = self._allocate_memory(
                inf_max_seq_len, inf_max_batch_size)
            memory_dict[self.layer_number] = (
                inference_key_memory, inference_value_memory)
        return memory_dict[self.layer_number]

    def _get_query_key_value(self, hidden_states, encoder_output):
        """Project the inputs to per-partition query, key and value heads.

        Returns ``(query, key, value)``, each shaped [s, b, np, hn].
        - The query always comes from ``hidden_states`` ([sq, b, h]).
        - For self-attention, key and value also come from
          ``hidden_states``.
        - For cross-attention, key and value come from ``encoder_output``
          ([sk, b, h]).
        """
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
            return query_layer, key_layer, value_layer

        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv_layer, _ = self.key_value(encoder_output)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             2 * self.hidden_size_per_attention_head)
        mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query_layer, _ = self.query(hidden_states)
        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_tensor_shape)
        return query_layer, key_layer, value_layer

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            inference_key_memory, inference_value_memory = \
                self._get_inference_memory(inference_params)

        # =====================
        # Query, Key, and Value
        # =====================
        query_layer, key_layer, value_layer = \
            self._get_query_key_value(hidden_states, encoder_output)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy the new keys/values into the cache, then attend over all
            # cached positions [0, sequence_end) of this batch slice.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocated result tensor: [b * np, sq, sk]. It is placed on the
        # query's device rather than torch.cuda.current_device(), so the
        # buffer always lives where baddbmm's other operands are.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the (uninitialized) contents of matmul_result are
        # ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Compute ``residual + dropout(x + bias)``.

    Args:
        x: tensor to which ``bias`` is added before dropout.
        bias: tensor broadcast-added to ``x``.
        residual: tensor added after dropout.
        prob: dropout probability.
        training: when False, dropout is the identity.
    """
    biased = x + bias
    dropped = torch.nn.functional.dropout(biased, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return a 4-argument ``(x, bias, residual, prob)`` callable.

    The returned callable forwards to ``bias_dropout_add`` with the given
    ``training`` flag bound.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout on."""
    dropout_enabled = True
    return bias_dropout_add(x, bias, residual, prob, dropout_enabled)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + (x + bias)``; dropout is disabled."""
    dropout_enabled = False
    return bias_dropout_add(x, bias, residual, prob, dropout_enabled)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Every layer runs self-attention followed by
    the MLP. Decoder layers additionally run cross-attention over
    ``encoder_output`` between the two. Each sub-block is followed by a
    bias + dropout + residual add.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        # If True, the residual branch starts from the layernorm output
        # instead of the layernorm input.
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP: a mixture of experts when num_experts is set.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def _select_bias_dropout_add(self):
        """Pick the bias + dropout + residual routine for the current mode."""
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if not self.bias_dropout_fusion:
            return get_bias_dropout_add(self.training)
        if self.training:
            return bias_dropout_add_fused_train
        return bias_dropout_add_fused_inference

    def _residual_add(self, bias_dropout_add_func, output, bias,
                      layernorm_output, layernorm_input):
        """Return ``residual + dropout(output + bias)``.

        The residual is ``layernorm_output`` when
        ``apply_residual_connection_post_layernorm`` is set, and
        ``layernorm_input`` otherwise.
        """
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input
        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            return bias_dropout_add_func(
                output,
                bias.expand_as(residual),
                residual,
                self.hidden_dropout)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]
        bias_dropout_add_func = self._select_bias_dropout_add()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params)
        # First residual connection (from the raw layer input).
        layernorm_input = self._residual_add(
            bias_dropout_add_func, attention_output, attention_bias,
            layernorm_output, hidden_states)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = self.inter_attention(
                layernorm_output,
                enc_dec_attn_mask,
                encoder_output=encoder_output)
            layernorm_input = self._residual_add(
                bias_dropout_add_func, attention_output, attention_bias,
                layernorm_output, layernorm_input)
            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP followed by the final residual connection.
        mlp_output, mlp_bias = self.mlp(layernorm_output)
        output = self._residual_add(
            bias_dropout_add_func, mlp_output, mlp_bias,
            layernorm_output, layernorm_input)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class: a stack of ``ParallelTransformerLayer`` modules.

    With pipeline parallelism, this module builds only the layers assigned to
    the current pipeline stage. With virtual pipelining, it builds only the
    current model chunk of that stage. The final layer norm is created only
    when ``post_process`` is set.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        """Build the layers owned by this pipeline stage / model chunk.

        Args:
            init_method: weight initializer forwarded to every layer.
            output_layer_init_method: output-projection initializer forwarded
                to every layer.
            layer_type: ``LayerType`` forwarded to every layer.
            self_attn_mask_type: ``AttnMaskType`` forwarded to every layer.
            pre_process: if True, ``forward`` consumes its ``hidden_states``
                argument; otherwise it uses the tensor given to
                ``set_input_tensor``.
            post_process: if True, a final layer norm is built and applied.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        # Filled in by set_input_tensor() when pre_process is False.
        self.input_tensor = None

        # Store activation checkpointing configuration.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers in this pipeline stage.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            # The per-stage layer count is what gets split into model chunks
            # below, so it is the quantity that must divide evenly. Checking
            # only args.num_layers lets e.g. 2 layers/stage with 4 chunks pass
            # and then silently build 0 layers per chunk.
            assert self.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        # Layer numbers are 1-based and shifted by this stage/chunk's offset.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based position ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Two strategies are supported, selected by
        ``self.activations_checkpoint_method``:

        * ``'uniform'``: split the local layers into consecutive chunks of
          ``activations_checkpoint_num_layers`` layers and checkpoint the input
          of each chunk. The last chunk is shorter when the layer count is not
          a multiple of the chunk size.
        * ``'block'``: checkpoint each of the first
          ``activations_checkpoint_num_layers`` layers individually and run
          the remaining layers without checkpointing.

        Returns:
            The hidden states after all local layers.

        Raises:
            ValueError: if the checkpoint method is unknown, or if 'uniform'
                is used with a chunk size smaller than 1.
        """
        def custom(start, end):
            # Build a closure running local layers [start, end); it receives
            # its tensors positionally as mpu.checkpoint passes them.
            def custom_forward(*inputs):
                x_, attention_mask, encoder_output, enc_dec_attn_mask = inputs[:4]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.activations_checkpoint_num_layers
            if chunk_size < 1:
                # A non-positive chunk size would never advance the loop below.
                raise ValueError(
                    "activations_checkpoint_num_layers must be >= 1 for the "
                    "'uniform' activation checkpoint method.")
            start = 0
            while start < self.num_layers:
                # Clamp so a trailing partial chunk does not index past the
                # last local layer.
                end = min(start + chunk_size, self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(start, end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                start = end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for index in range(self.num_layers):
                if index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(index, index + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(index, index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the local layer stack.

        Args:
            hidden_states: input activations, used only when ``pre_process``
                is True. Per the original comment, these arrive as [b s h] and
                are transposed to [s b h]. When ``pre_process`` is False this
                argument is ignored and ``self.input_tensor`` is used instead.
            attention_mask: self-attention mask forwarded to every layer.
            encoder_output: optional encoder activations. Their first two
                dimensions are swapped before use (presumably [b s h] ->
                [s b h] like ``hidden_states`` — TODO confirm).
            enc_dec_attn_mask: optional encoder-decoder attention mask.
            inference_params: optional inference state forwarded to every
                layer. It cannot be combined with activation checkpointing.

        Returns:
            With ``post_process``: the final-layer-normed output transposed
            back to [b s h]. Otherwise: the raw [s b h] hidden states, for the
            next pipeline stage.
        """
        # Checks.
        if inference_params is not None:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.float()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output