# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied from NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """

import numbers
8
import torch
9
10
11
12
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

13
from megatron.core.utils import make_viewless_tensor
# Optional dependency: apex's persistent ("fast") layer norm kernel.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except Exception:
    # Best effort: apex (or its contrib layer norm extension) may be missing
    # or fail to load. MixedFusedLayerNorm then falls back to the
    # non-persistent kernel. `Exception` (not a bare `except:`) still lets
    # KeyboardInterrupt / SystemExit propagate.
    HAVE_PERSIST_LAYER_NORM = False

# Handle to the compiled `fused_mix_prec_layer_norm_cuda` extension module.
# It stays None until MixedFusedLayerNorm.__init__ imports it, and it is read
# by FusedLayerNormAffineFunction.
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd wrapper around the ``fused_mix_prec_layer_norm_cuda``
    affine layer-norm kernels (``forward_affine`` / ``backward_affine``).

    The extension is looked up through the module-level
    ``fused_mix_prec_layer_norm_cuda`` global. That global is populated when a
    ``MixedFusedLayerNorm`` is constructed. Calling ``apply`` before then
    fails, because the global is still ``None``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Run the fused affine forward kernel and stash state for backward.

        Args:
            ctx: autograd context.
            input: tensor to normalize.
            weight: affine weight tensor forwarded to ``forward_affine``.
            bias: affine bias tensor forwarded to ``forward_affine``.
            normalized_shape: normalized trailing shape passed to the kernel.
            eps: epsilon passed to the kernel.

        Returns:
            The output tensor produced by ``forward_affine``.
        """
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # Contiguous copies are passed to the kernel (presumably a layout
        # requirement of the extension) and the same copies are saved so
        # backward sees exactly what forward consumed.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        # mean / invvar come back from the forward kernel and are reused by
        # backward_affine instead of being recomputed.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. input, weight and bias.

        Returns:
            A 5-tuple matching ``forward``'s inputs. ``normalized_shape`` and
            ``eps`` are non-tensor arguments, so their gradients are ``None``.
        """
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = (
            fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps))
        return grad_input, grad_weight, grad_bias, None, None


class MixedFusedLayerNorm(torch.nn.Module):
    """Layer norm backed by fused CUDA kernels.

    Uses apex's persistent ``FastLayerNormFN`` kernel only when it all of the
    following hold:

    * apex provides it (``HAVE_PERSIST_LAYER_NORM``);
    * the caller allows it (``no_persist_layer_norm=False``);
    * the normalized shape is a single hidden size the kernel supports.

    Otherwise it uses ``FusedLayerNormAffineFunction``.

    Args:
        normalized_shape: int or sequence of ints giving the trailing shape
            that is normalized.
        eps: epsilon forwarded to the kernels.
        no_persist_layer_norm: if True, never use the persistent kernel.
        sequence_parallel: stored on the module and tagged onto the weight and
            bias parameters as a ``sequence_parallel`` attribute.
        apply_layernorm_1p: if True, the weight parameter is stored as an
            offset from 1. It is zero-initialized and used as ``weight + 1``.
    """

    # Hidden sizes supported by the persistent layer norm kernel. Any other
    # size falls back to the non-persistent kernel.
    _PERSIST_LN_HIDDEN_SIZES = frozenset((
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
        12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720, 32768,
        40960, 49152, 65536))

    def __init__(self, normalized_shape, eps=1e-5,
                 no_persist_layer_norm=True,
                 sequence_parallel=False,
                 apply_layernorm_1p=False):
        super(MixedFusedLayerNorm, self).__init__()

        self.apply_layernorm_1p = bool(apply_layernorm_1p)
        # Added to the stored weight in forward(). With layernorm-1p the
        # parameter holds (scale - 1), hence the +1.
        self.weight_adjustment = 1 if self.apply_layernorm_1p else 0

        # Load the compiled extension and publish it through the module
        # global read by FusedLayerNormAffineFunction.
        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
            "fused_mix_prec_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps

        # Fall back to the non-persistent kernel when apex's fast layer norm
        # is unavailable or the hidden size is not supported by it. The check
        # runs on the normalized shape so ints, tuples and torch.Size all
        # qualify.
        if (not HAVE_PERSIST_LAYER_NORM
                or not self._is_persist_supported_shape(
                    self.normalized_shape)):
            no_persist_layer_norm = True

        # Allocated uninitialized, then filled by reset_parameters().
        self.weight = Parameter(torch.Tensor(*self.normalized_shape))
        self.bias = Parameter(torch.Tensor(*self.normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # Set the sequence parallelism flag on the weight and bias parameters.
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    @classmethod
    def _is_persist_supported_shape(cls, normalized_shape):
        """Return True if ``normalized_shape`` is a single hidden size that
        the persistent kernel supports.

        Accepts an int or a sequence of ints (e.g. a tuple or ``torch.Size``).
        """
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        shape = tuple(normalized_shape)
        return len(shape) == 1 and shape[0] in cls._PERSIST_LN_HIDDEN_SIZES

    def reset_parameters(self):
        """Initialize the parameters.

        The bias is always set to 0. The weight is set to 0 under
        layernorm-1p (it is used as ``weight + 1``) and to 1 otherwise.
        """
        if self.apply_layernorm_1p:
            init.zeros_(self.weight)
        else:
            init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Normalize ``input`` with the kernel selected in ``__init__``."""
        weight = self.weight + self.weight_adjustment

        if self.no_persist_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, weight, self.bias, self.normalized_shape, self.eps)

        output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor
        # is created to prevent this.
        return make_viewless_tensor(inp=output,
                                    requires_grad=input.requires_grad,
                                    keep_graph=True)