# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied fron NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

from megatron.core.utils import make_viewless_tensor

# Apex's persistent (single-kernel) layer norm is optional; fall back to the
# generic fused kernel when it is not installed.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False

# Handle to the fused mixed-precision layer norm CUDA extension; imported
# lazily in MixedFusedLayerNorm.__init__ below.
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
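  """Autograd wrapper around the fused mixed-precision layer norm CUDA
  kernels (affine variant). The forward kernel also returns the per-row mean
  and inverse variance, which are saved for the backward pass."""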

  @staticmethod
  def forward(ctx, input, weight, bias, normalized_shape, eps):

    ctx.normalized_shape = normalized_shape
    ctx.eps = eps
    input_ = input.contiguous()
    weight_ = weight.contiguous()
    bias_ = bias.contiguous()
    output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

    return output


  @staticmethod
  def backward(ctx, grad_output):

    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    grad_input, grad_weight, grad_bias \
      = fused_mix_prec_layer_norm_cuda.backward_affine(
        grad_output.contiguous(), mean, invvar,
        input_, ctx.normalized_shape,
        weight_, bias_, ctx.eps)

    return grad_input, grad_weight, grad_bias, None, None



class MixedFusedLayerNorm(torch.nn.Module):

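  """Layer norm that dispatches to the fused mixed-precision CUDA kernels,
  with optional use of Apex's persistent kernel, a sequence-parallel flag on
  the parameters, and the layernorm-1p parameterization."""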
  def __init__(self, normalized_shape, eps=1e-5,
               no_persist_layer_norm=True,
               sequence_parallel=False,
               apply_layernorm_1p=False):
        super(MixedFusedLayerNorm, self).__init__()

        # With layernorm-1p, the learnable weight is stored as an offset from
        # 1 (so it is initialized to zero in reset_parameters) and 1 is added
        # back at forward time via weight_adjustment.
        self.apply_layernorm_1p = apply_layernorm_1p
        self.weight_adjustment = 1 if apply_layernorm_1p else 0

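        # Import the fused mixed-precision layer norm CUDA extension and keep
        # it in the module-level handle used by FusedLayerNormAffineFunction.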
        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
          "fused_mix_prec_layer_norm_cuda")

        # List of hiddens sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel
        
        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

  def reset_parameters(self):

    if self.apply_layernorm_1p:
        init.zeros_(self.weight)
        init.zeros_(self.bias)
    else: 
        init.ones_(self.weight)
        init.zeros_(self.bias)

  def forward(self, input):

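    # Use the generic fused autograd function when the persistent kernel is
    # unavailable or disabled for this hidden size; otherwise use Apex's
    # FastLayerNormFN.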
    if self.no_persist_layer_norm:
        return FusedLayerNormAffineFunction.apply(
          input, self.weight + self.weight_adjustment,
          self.bias, self.normalized_shape, self.eps)
    else:
        output = FastLayerNormFN.apply(
          input, self.weight + self.weight_adjustment, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        output = make_viewless_tensor(inp = output,
                                      requires_grad = input.requires_grad,
                                      keep_graph = True)

        return output



#class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
#  def reset_parameters(self):
#    init.zeros_(self.weight)
#    init.zeros_(self.bias)
#
#  def forward(self, input):
#
#    if self.no_persist_layer_norm:
#        return FusedLayerNormAffineFunction.apply(
#          input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
#    else:
#        output = FastLayerNormFN.apply(
#          input, self.weight + 1, self.bias, self.eps)
#
#        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
#        # a populated '_base' field). This will result in schedule.py's
#        # deallocate_output_tensor() throwing an error, so a viewless tensor is
#        # created to prevent this.
#        output = make_viewless_tensor(inp = output,
#                                      requires_grad = input.requires_grad,
#                                      keep_graph = True)
#
#        return output
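
# Minimal usage sketch (not part of the original module). It assumes the
# fused_mix_prec_layer_norm_cuda extension is built and a CUDA device is
# available; Apex's persistent kernel is used only if it was importable above.
if __name__ == "__main__":
    hidden_size = 1024  # one of the hidden sizes supported by the persistent kernel
    layer_norm = MixedFusedLayerNorm(hidden_size,
                                     eps=1e-5,
                                     no_persist_layer_norm=False,
                                     sequence_parallel=False,
                                     apply_layernorm_1p=False).cuda().half()
    # Megatron activations are laid out as [sequence length, micro batch, hidden].
    x = torch.randn(1024, 4, hidden_size, dtype=torch.float16, device='cuda')
    y = layer_norm(x)
    assert y.shape == x.shape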