transformer.py 36.7 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
19
import torch.nn.functional as F
20

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
25
from megatron.model import LayerNorm
26
27
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
28
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
29
30
31
32
33
34
35
36
37
38
39

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length (sq / sk: query / key sequence lengths in attention)
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

45
46
47
48
49
50

class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    In training mode, the whole branch output of each sample in the batch is
    zeroed with probability ``drop_prob``. Surviving samples are scaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the
    input is returned unchanged.

    The input is expected in this file's [s, b, h] layout (at least 2-D),
    so the per-sample mask is drawn along dim 1 (batch) and broadcast over
    every other dim.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # hidden_state: [s, b, h]. One mask value per sample (dim 1), broadcast
        # over the sequence dim and all trailing dims. Drawing it along dim 0
        # would drop individual sequence positions instead of whole samples.
        shape = (1, hidden_state.shape[1]) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        if keep_prob == 0.:
            # drop_prob == 1: every path is dropped (mask is all zeros).
            # Skip the 1/keep_prob rescale, which would produce 0/0 = NaN.
            return hidden_state * random_tensor
        output = hidden_state.div(keep_prob) * random_tensor
        return output


68
69
70
71
72
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    Projects the h-wide input up to ``args.ffn_hidden_size``, applies a
    GeLU-family nonlinearity, and projects back down to h.

    Both projections are built with ``skip_bias_add=True``, so ``forward``
    returns ``(output, output_bias)``. The final bias is handed back
    unapplied for the caller to add.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Up-projection h -> ffn_hidden_size. The output stays partitioned
        # across tensor-parallel ranks (gather_output=False).
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Nonlinearity selection. When bias_gelu_fusion is enabled, forward()
        # uses the fused bias+GeLU kernel and activation_func goes unused.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Down-projection ffn_hidden_size -> h. Its input is the still-
        # partitioned output of dense_h_to_4h (input_is_parallel=True).
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, 4hp]; the bias comes back separately.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            projected = self.activation_func(projected + projected_bias)
        else:
            projected = bias_gelu_impl(projected, projected_bias)

        # [s, b, 4hp] -> [s, b, h]
        result, result_bias = self.dense_4h_to_h(projected)
        return result, result_bias
118

rprenger's avatar
rprenger committed
119
120
121
122
class SwitchMLP(MegatronModule):
    """Top-1 mixture-of-experts MLP.

    A linear router scores every token against ``args.num_experts``
    experts, each a ParallelMLP. Every token is processed only by its
    highest-probability expert. The expert's output and (expanded) bias are
    then multiplied by that router probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            ParallelMLP(init_method, output_layer_init_method)
            for _ in range(args.num_experts))

    def forward(self, hidden_states):
        # The transformer layer passes [s, b, h] (this file's notation). Only
        # the last dim is semantically used here. The two leading dims are
        # flattened into a token axis and restored at the end.
        lead0 = hidden_states.size(0)
        lead1 = hidden_states.size(1)
        width = hidden_states.size(2)

        # Router probabilities over experts, then top-1 expert per token.
        probs = torch.nn.functional.softmax(self.router(hidden_states), dim=2)
        top_prob, top_expert = torch.max(probs, dim=2)

        # TODO (rprenger) this could be made easier to read
        tokens = hidden_states.view(-1, width)           # [T, h]
        top_prob = top_prob.unsqueeze(2).view(-1, 1)     # [T, 1]
        top_expert = top_expert.view(-1)                 # [T]

        output_total = torch.empty_like(tokens)
        output_bias_total = torch.empty_like(tokens)

        # TODO (rprenger) experts run serially; this could be parallelized.
        # top_expert holds argmax indices in [0, num_experts), so every token
        # row is written by exactly one expert below.
        for expert_id, expert in enumerate(self.experts):
            picked = (top_expert == expert_id).nonzero()     # [k, 1]
            expert_out, expert_bias = expert(tokens[picked, :])  # [k, 1, h]
            output_total[picked, :] = expert_out
            output_bias_total[picked, :] = expert_bias.expand_as(expert_out)

        # Weight by the chosen expert's probability and restore leading dims.
        output_total = (output_total * top_prob).view(lead0, lead1, width)
        output_bias_total = (output_bias_total * top_prob).view(lead0, lead1, width)
        return output_total, output_bias_total
166

167
class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention layer (self- or cross-attention).

    ``forward`` takes ``hidden_states`` of shape [sq, b, h]. It returns
    ``(output, bias)``, where ``output`` is [sq, b, h] and ``bias`` is the
    output projection's bias. The bias is returned unapplied because the
    projection uses ``skip_bias_add=True``.

    Args:
        init_method: initializer for the QKV (self-attn) or Q and KV
            (cross-attn) projections.
        output_layer_init_method: initializer for the output projection.
        layer_number: layer index, clamped to at least 1. It keys the
            inference key/value cache. With apply_query_key_layer_scaling it
            is also the extra score-scaling coefficient.
        attention_type: AttnType.self_attn or AttnType.cross_attn.
        attn_mask_type: mask type forwarded to FusedScaleMaskSoftmax.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query/key layer scaling always runs the softmax in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # Scores are divided by norm_factor = sqrt(hn) (times layer_number
        # when layer scaling is on). coeff is handed to the softmax as well,
        # presumably so it can undo the extra scaling in fp32; confirm in
        # FusedScaleMaskSoftmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized key or value cache buffer.

        Shape is [inference_max_sequence_len, batch_size, np, hn], in
        ``params_dtype``, on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Run attention.

        Args:
            hidden_states: [sq, b, h] queries (and keys/values for self-attn).
            attention_mask: passed as-is to ``scale_mask_softmax``.
            encoder_output: [sk, b, h] key/value source; used only for
                cross-attention.
            inference_params: optional incremental-decoding state. It must
                expose key_value_memory_dict, max_sequence_len,
                max_batch_size, batch_size_offset and sequence_len_offset.
                When given, new keys/values are written into a per-layer
                cache and attention runs over the cached prefix.

        Returns:
            (output [sq, b, h], bias) from the output projection; the bias
            is not yet added.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over everything cached so far (positions [0, sequence_end)).
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the (uninitialized) contents of matmul_result are
        # ignored; it only serves as the output buffer.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


441
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias, p=prob)``.

    Called both eagerly (via get_bias_dropout_add) and from the
    TorchScript-compiled fused variants, so it must remain scriptable. The
    explicit annotations give TorchScript the float/bool types of ``prob``
    and ``training``; unannotated arguments would be assumed to be Tensors.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return an eager bias-dropout-add callable with ``training`` bound.

    The returned function takes ``(x, bias, residual, prob)`` and forwards
    to ``bias_dropout_add`` together with the captured ``training`` flag.
    """
    def bias_dropout_add_with_mode(x, bias, residual, prob):
        # ``training`` is captured from the enclosing call.
        return bias_dropout_add(x, bias, residual, prob, training)

    return bias_dropout_add_with_mode


@torch.jit.script
def bias_dropout_add_fused_train(
        x: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
        prob: float) -> torch.Tensor:
    """TorchScript bias + dropout + residual add with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(
        x: torch.Tensor,
        bias: torch.Tensor,
        residual: torch.Tensor,
        prob: float) -> torch.Tensor:
    """TorchScript bias + dropout + residual add with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, False)
468
469
470
471
472


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    The layer applies, in order:
      1. input LayerNorm -> self-attention -> bias/dropout/residual add
         -> LayerNorm;
      2. decoder layers only: cross-attention over ``encoder_output``
         -> bias/dropout/residual add -> LayerNorm;
      3. MLP -> bias/dropout/residual add.
    With ``apply_residual_connection_post_layernorm`` the residual is taken
    from the LayerNorm output instead of the LayerNorm input.

    Args:
        init_method: initializer for input-side projections.
        output_layer_init_method: initializer for output projections.
        layer_number: layer index, forwarded to the attention modules.
        layer_type: LayerType.encoder or LayerType.decoder; decoder layers
            also build cross-attention.
        self_attn_mask_type: mask type for self-attention.
        drop_path_rate: stochastic-depth rate. When > 0, a DropPath is
            applied to the self-attention and MLP branch outputs.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.model_parallel_memory_opt)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth module only exists when it is actually used.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.model_parallel_memory_opt)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.model_parallel_memory_opt)

        # MLP: top-1 mixture of experts when num_experts is set.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Select the bias-dropout-add routine up front. It must always be
        # bound: the decoder cross-attention path below uses it even when
        # drop_path is enabled.
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


634
635
636
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    This layer exists only for the case where a standalone embedding stage
    is used (i.e., ``args.standalone_embedding_stage == True``):

    - Pipeline rank 0 is then assigned zero transformer layers.
    - For virtual pipeline ranks >= 1, no model parameters are created at
      all (virtual rank 0 holds the input embedding).

    The model's input and output tensors would then be the same object.
    That breaks memory optimizations applied to the output tensor, such as
    deallocating it. This layer breaks the aliasing by returning a clone of
    its input. Ranks holding a no-op layer are generally under-utilized in
    both compute and memory, so the extra copy is not a performance concern.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # Only hidden_states is used. The remaining parameters mirror
        # ParallelTransformerLayer.forward so the two are interchangeable.
        copied = hidden_states.clone()
        return copied


660
661
662
class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the stack of ParallelTransformerLayer modules assigned to this
    pipeline stage (or virtual pipeline chunk), plus an optional final
    LayerNorm on the last stage (`post_process`).
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the layers resident on this rank.

        Args:
            init_method: weight initializer forwarded to every layer.
            output_layer_init_method: output-projection initializer forwarded
                to every layer.
            layer_type: LayerType of the stack (encoder or decoder).
            self_attn_mask_type: AttnMaskType forwarded to every layer.
            pre_process: if True, forward() transposes its input itself
                instead of reading the tensor given to set_input_tensor().
            post_process: if True, a final LayerNorm is built and applied.
            drop_path_rate: maximum stochastic-depth rate; per-layer rates are
                linearly spaced from 0 to this value over args.num_layers.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        # Distributed checkpointed activations are disabled when the
        # model-parallel memory optimization (sequence parallelism) is on.
        self.distribute_checkpointed_activations = \
            args.distribute_checkpointed_activations and not args.model_parallel_memory_opt

        self.model_parallel_memory_opt = args.model_parallel_memory_opt

        # Number of layers resident on this pipeline rank.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # One stochastic-depth rate per global layer, indexed by layer_number - 1.
        self.drop_path_rates = [
            rate.item()
            for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder ranks start after the encoder's pipeline ranks.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            # Layer numbers are 1-based global indices.
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.model_parallel_memory_opt)

    def _get_layer(self, layer_number):
        """Return the layer at 0-based local position `layer_number`."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Uses self.activations_checkpoint_method:
            'uniform': checkpoint consecutive chunks of
                activations_checkpoint_num_layers layers each; the last chunk
                is clipped to the number of local layers.
            'block': checkpoint only the first activations_checkpoint_num_layers
                layers individually and run the rest without checkpointing.

        Raises:
            ValueError: for any other checkpoint method.
        """
        def custom(start, end):
            """Build a closure that runs local layers [start, end)."""
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_start = 0
            while chunk_start < self.num_layers:
                # Clamp so a trailing partial chunk never indexes past the
                # last local layer (num_layers need not be a multiple of the
                # chunk size).
                chunk_end = min(chunk_start + self.activations_checkpoint_num_layers,
                                self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(chunk_start, chunk_end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                chunk_start = chunk_end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory removing redundant re-computation.
            for layer_index in range(self.num_layers):
                if layer_index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(layer_index, layer_index + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(layer_index, layer_index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the local transformer layers.

        Args:
            hidden_states: used only when pre_process is True, in which case
                it is transposed from [b, s, h] to [s, b, h]. Otherwise the
                tensor passed to set_input_tensor() is used.
            attention_mask: forwarded to every layer.
            encoder_output: optional encoder output for decoder layers;
                transposed to [s, b, h] unless model_parallel_memory_opt.
            enc_dec_attn_mask: forwarded to every layer.
            inference_params: forwarded to every layer; incompatible with
                activation checkpointing.

        Returns:
            With post_process: the normalized hidden states transposed back
            to [b, s, h] (after a sequence-parallel gather when
            model_parallel_memory_opt), except for the encoder of an
            encoder-decoder model under model_parallel_memory_opt, which
            returns them in [s, b, h] without gathering. Without post_process:
            the raw hidden states in [s, b, h].
        """
        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

            if self.model_parallel_memory_opt:
                hidden_states = mpu.scatter_to_sequence_parallel_region(hidden_states)

        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output [b s h] --> [s b h]. Under
        # model_parallel_memory_opt no transpose (or scatter) is done here:
        # the encoder branch of this class returns its output already in
        # [s b h] and un-gathered, which is presumably what arrives as
        # encoder_output (confirm against the enc-dec model wiring).
        if encoder_output is not None and \
                not self.model_parallel_memory_opt:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            hidden_states = self.final_layernorm(hidden_states)

            if self.layer_type == LayerType.encoder and \
                    self.model_type == ModelType.encoder_and_decoder and \
                    self.model_parallel_memory_opt:
                # Keep the encoder output scattered and in [s b h].
                output = hidden_states
            else:
                if self.model_parallel_memory_opt:
                    hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)

                # Reverting data format change [s b h] --> [b s h].
                output = hidden_states.transpose(0, 1).contiguous()
        else:
            output = hidden_states

        return output