# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied fron NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

from megatron import get_args
from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

# Handle to the fused_mix_prec_layer_norm_cuda extension; imported lazily in
# MixedFusedLayerNorm.__init__.
global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
  """Autograd wrapper around the fused mixed-precision layer norm
  CUDA kernels (forward_affine / backward_affine)."""

  @staticmethod
  def forward(ctx, input, weight, bias, normalized_shape, eps):
    ctx.normalized_shape = normalized_shape
    ctx.eps = eps
    input_ = input.contiguous()
    weight_ = weight.contiguous()
    bias_ = bias.contiguous()
    output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
    return output

  @staticmethod
  def backward(ctx, grad_output):
    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    grad_input, grad_weight, grad_bias \
      = fused_mix_prec_layer_norm_cuda.backward_affine(
        grad_output.contiguous(), mean, invvar,
        input_, ctx.normalized_shape,
        weight_, bias_, ctx.eps)

    return grad_input, grad_weight, grad_bias, None, None
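

# ---------------------------------------------------------------------------
# Reference sketch (not part of the original APEX/Megatron code): a plain
# PyTorch version of the layer norm that the fused forward_affine kernel is
# expected to compute, handy for unit-testing the autograd function above.
# Exact kernel semantics (e.g. accumulation precision) are an assumption.
# ---------------------------------------------------------------------------
def _reference_layer_norm_affine(input, weight, bias, normalized_shape, eps):
    # Normalize over the trailing `normalized_shape` dimensions and apply the
    # elementwise affine transform, computing statistics in fp32.
    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape = (normalized_shape,)
    dims = tuple(range(-len(normalized_shape), 0))
    mean = input.float().mean(dim=dims, keepdim=True)
    var = input.float().var(dim=dims, unbiased=False, keepdim=True)
    output = (input.float() - mean) * torch.rsqrt(var + eps)
    return (output * weight.float() + bias.float()).to(input.dtype)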


class MixedFusedLayerNorm(torch.nn.Module):
  """LayerNorm that uses apex's persistent kernel when it is available and
  the hidden size supports it, and the fused mixed-precision kernel
  otherwise."""

  def __init__(self, normalized_shape, eps=1e-5,
               no_persist_layer_norm=True,
               sequence_parallel=False):
        super(MixedFusedLayerNorm, self).__init__()

        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
          "fused_mix_prec_layer_norm_cuda")

        # List of hidden sizes supported by the persistent layer norm kernel.
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))

        # With --apply-layernorm-1p the learnable weight is stored as its
        # deviation from 1 (the shift is added back in forward()), so
        # reset_parameters() initializes it to zero rather than one.
        args = get_args()
        self.weight_adjustment = 1 if args.apply_layernorm_1p else 0

        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # Set the sequence-parallel attribute on the weight and bias so
        # downstream code (e.g. gradient reduction) can identify them.
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)


  def reset_parameters(self):

    if self.weight_adjustment:
        # Layernorm-1p: the stored weight is a deviation from 1, so the
        # effective initial weight (weight + adjustment) is still 1.
        init.zeros_(self.weight)
    else:
        init.ones_(self.weight)
    init.zeros_(self.bias)


  def forward(self, input):

    if self.no_persist_layer_norm:
        return FusedLayerNormAffineFunction.apply(
          input, self.weight + self.weight_adjustment, self.bias,
          self.normalized_shape, self.eps)
    else:
        output = FastLayerNormFN.apply(
          input, self.weight + self.weight_adjustment, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        output = make_viewless_tensor(inp = output,
                                      requires_grad = input.requires_grad,
                                      keep_graph = True)

        return output



#class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
#  def reset_parameters(self):
#    init.zeros_(self.weight)
#    init.zeros_(self.bias)
#
#  def forward(self, input):
#
#    if self.no_persist_layer_norm:
#        return FusedLayerNormAffineFunction.apply(
#          input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
#    else:
#        output = FastLayerNormFN.apply(
#          input, self.weight + 1, self.bias, self.eps)
#
#        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
#        # a populated '_base' field). This will result in schedule.py's
#        # deallocate_output_tensor() throwing an error, so a viewless tensor is
#        # created to prevent this.
#        output = make_viewless_tensor(inp = output,
#                                      requires_grad = input.requires_grad,
#                                      keep_graph = True)
#
#        return output
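

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file).  It assumes that
# Megatron's global args have already been initialized (so get_args()
# succeeds, e.g. after initialize_megatron()), that the
# fused_mix_prec_layer_norm_cuda extension is built, and that apex's
# FastLayerNormFN is importable; none of this is guaranteed in a standalone
# checkout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hidden_size = 1024
    layer_norm = MixedFusedLayerNorm(hidden_size,
                                     eps=1e-5,
                                     no_persist_layer_norm=False,
                                     sequence_parallel=False).cuda().half()

    # [sequence, batch, hidden] layout, as used inside Megatron's transformer
    # layers.
    x = torch.randn(2048, 4, hidden_size,
                    dtype=torch.float16, device='cuda', requires_grad=True)
    y = layer_norm(x)

    # Compare against the plain-PyTorch reference defined above.
    ref = _reference_layer_norm_affine(
        x, layer_norm.weight + layer_norm.weight_adjustment, layer_norm.bias,
        layer_norm.normalized_shape, layer_norm.eps)
    print('max abs diff vs reference layer norm:',
          (y.float() - ref.float()).abs().max().item())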