# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied from NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

from megatron.core.utils import make_viewless_tensor

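# The persistent ("fast") layer norm kernel lives in apex.contrib and may not
# be installed; record its availability so MixedFusedLayerNorm can fall back
# to the non-persistent fused kernel.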
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False

global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
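  """Autograd function wrapping the fused_mix_prec_layer_norm_cuda kernels
  for layer norm with learnable weight and bias (affine parameters)."""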

  @staticmethod
  def forward(ctx, input, weight, bias, normalized_shape, eps):
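    """Run the fused affine layer norm forward kernel on contiguous copies
    of the inputs and save the tensors and statistics (mean, inverse
    variance) needed for the backward pass."""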

    ctx.normalized_shape = normalized_shape
    ctx.eps = eps
    input_ = input.contiguous()
    weight_ = weight.contiguous()
    bias_ = bias.contiguous()
    output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

    return output


  @staticmethod
  def backward(ctx, grad_output):
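    """Compute gradients for input, weight, and bias with the fused backward
    kernel, reusing the saved mean and inverse variance. The two trailing
    None values correspond to normalized_shape and eps, which need no
    gradient."""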

    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    grad_input, grad_weight, grad_bias \
      = fused_mix_prec_layer_norm_cuda.backward_affine(
        grad_output.contiguous(), mean, invvar,
        input_, ctx.normalized_shape,
        weight_, bias_, ctx.eps)

    return grad_input, grad_weight, grad_bias, None, None



class MixedFusedLayerNorm(torch.nn.Module):
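  """Layer norm backed by fused CUDA kernels.

  Uses apex's persistent FastLayerNormFN kernel when it is available,
  supports the given hidden size, and no_persist_layer_norm is False;
  otherwise falls back to FusedLayerNormAffineFunction above. The
  sequence_parallel flag is recorded on the weight and bias parameters for
  use elsewhere in Megatron."""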

  def __init__(self, normalized_shape, eps=1e-5,
               no_persist_layer_norm=True,
               sequence_parallel=False):
        super(MixedFusedLayerNorm, self).__init__()

        global fused_mix_prec_layer_norm_cuda
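        # Load the fused layer norm CUDA extension lazily, when a layer norm
        # module is constructed rather than at import time.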
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
          "fused_mix_prec_layer_norm_cuda")

        # List of hidden sizes supported in the persistent layer norm kernel.
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
            24576, 25600, 30720, 32768, 40960, 49152, 65536]
        if normalized_shape not in persist_ln_hidden_sizes or \
                not HAVE_PERSIST_LAYER_NORM:
            no_persist_layer_norm = True

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel
        
        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)


  def reset_parameters(self):

    init.ones_(self.weight)
    init.zeros_(self.bias)


  def forward(self, input):

    if self.no_persist_layer_norm:
        return FusedLayerNormAffineFunction.apply(
          input, self.weight, self.bias, self.normalized_shape, self.eps)
    else:
        output = FastLayerNormFN.apply(
          input, self.weight, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        output = make_viewless_tensor(inp = output,
                                      requires_grad = input.requires_grad,
                                      keep_graph = True)

        return output



class MixedFusedLayerNorm1P(MixedFusedLayerNorm):
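  """Variant of MixedFusedLayerNorm using a "weight + 1" parameterization:
  the weight is initialized to zero and weight + 1 is passed to the kernels,
  so the parameter stores the deviation from a unit gain."""
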
  def reset_parameters(self):
    init.zeros_(self.weight)
    init.zeros_(self.bias)

  def forward(self, input):

    if self.no_persist_layer_norm:
        return FusedLayerNormAffineFunction.apply(
          input, self.weight + 1, self.bias, self.normalized_shape, self.eps)
    else:
        output = FastLayerNormFN.apply(
          input, self.weight + 1, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        output = make_viewless_tensor(inp = output,
                                      requires_grad = input.requires_grad,
                                      keep_graph = True)

        return output
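

# Illustrative usage sketch (kept as a comment; not executed on import). It
# assumes apex and the fused_mix_prec_layer_norm_cuda extension are built for
# the local GPU, and the tensor shapes are arbitrary examples:
#
#   hidden_size = 1024
#   layer_norm = MixedFusedLayerNorm(hidden_size, eps=1e-5,
#                                    no_persist_layer_norm=False).cuda()
#   x = torch.randn(2048, 4, hidden_size, device='cuda')
#   y = layer_norm(x)  # same shape as x, normalized over the last dimension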