fused_layer_norm.py 3.9 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""This code is copied fron NVIDIA apex:
      https://github.com/NVIDIA/apex
5
   with some changes. """
6
7

import numbers
8
import torch
9
10
11
12
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

Lawrence McAfee's avatar
fixed.  
Lawrence McAfee committed
13
14
from megatron.mpu import make_viewless_tensor

15
16
17
18
19
# Optional dependency: apex's persistent ("fast") layer norm kernel. When it
# cannot be imported, MixedFusedLayerNorm always falls back to the
# non-persistent fused kernel. Only import failures are treated as "not
# available"; anything else (e.g. KeyboardInterrupt) propagates.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False
Sangkug Lym's avatar
Sangkug Lym committed
20

21
22
23
# Handle to the compiled ``fused_mix_prec_layer_norm_cuda`` extension. It stays
# ``None`` until ``MixedFusedLayerNorm.__init__`` imports the extension and
# rebinds this name via its own ``global`` declaration. (A ``global`` statement
# at module scope would be a no-op, so none is used here.)
fused_mix_prec_layer_norm_cuda = None

24

25
26
27
28
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function for affine layer norm via the fused CUDA extension.

    Both passes call into the module-level ``fused_mix_prec_layer_norm_cuda``
    handle. That handle is ``None`` until a ``MixedFusedLayerNorm`` has been
    constructed, so calling ``apply`` earlier fails with an ``AttributeError``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Normalize ``input`` and apply the affine ``weight``/``bias``.

        Args:
            ctx: autograd context; receives ``normalized_shape``/``eps`` and
                the tensors ``backward`` needs.
            input: tensor to normalize.
            weight: affine scale, shaped like ``normalized_shape``.
            bias: affine shift, shaped like ``normalized_shape``.
            normalized_shape: shape of the normalized dimensions, passed
                straight to the extension.
            eps: epsilon passed straight to the extension.

        Returns:
            The output tensor produced by ``forward_affine``.
        """
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps

        # Hand contiguous buffers to the extension (presumably its kernels
        # assume dense storage); the same contiguous tensors are saved for
        # the backward pass.
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()

        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)

        # mean/invvar are statistics returned by the forward kernel and are
        # required by backward_affine, so keep them instead of recomputing.
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        ``normalized_shape`` and ``eps`` are not differentiable inputs, so
        their gradient slots are ``None``.
        """
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = (
            fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps))
        return grad_input, grad_weight, grad_bias, None, None
54
55
56
57



class MixedFusedLayerNorm(torch.nn.Module):
    """Layer norm with learnable ``weight`` and ``bias`` backed by fused kernels.

    Two kernels are available:

    * apex's persistent kernel (``FastLayerNormFN``). It is used only when
      apex could be imported, the caller passed
      ``no_persist_layer_norm=False``, and the normalized shape is a single
      hidden size listed in ``_PERSIST_LN_HIDDEN_SIZES``.
    * Otherwise, the ``fused_mix_prec_layer_norm_cuda`` extension via
      ``FusedLayerNormAffineFunction``.
    """

    # Hidden sizes supported by the persistent layer norm kernel. Any other
    # size falls back to the non-persistent kernel.
    _PERSIST_LN_HIDDEN_SIZES = frozenset((
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
        12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720,
        32768, 40960, 49152, 65536,
    ))

    def __init__(self, normalized_shape, eps=1e-5,
                 no_persist_layer_norm=True,
                 sequence_parallel=False):
        """Build the layer norm and pick the kernel used by ``forward``.

        Args:
            normalized_shape: int or sequence of ints giving the shape of the
                normalized dimensions (and of ``weight``/``bias``).
            eps: epsilon forwarded to the layer norm kernel.
            no_persist_layer_norm: if True, never use the persistent kernel.
                Forced to True when apex's kernel is unavailable or does not
                support ``normalized_shape``.
            sequence_parallel: stored on the module and attached as a
                ``sequence_parallel`` attribute to ``weight`` and ``bias``.
        """
        super(MixedFusedLayerNorm, self).__init__()

        # Import the CUDA extension and publish it through the module-level
        # handle that FusedLayerNormAffineFunction reads.
        global fused_mix_prec_layer_norm_cuda
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
            "fused_mix_prec_layer_norm_cuda")

        # Normalize first so that an int, a 1-tuple and a torch.Size with the
        # same hidden size are all treated alike by the eligibility check.
        normalized_shape = self._normalize_shape(normalized_shape)
        if not (HAVE_PERSIST_LAYER_NORM
                and self._supports_persist_kernel(normalized_shape)):
            no_persist_layer_norm = True

        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # Tag the parameters so that code elsewhere can see whether they
        # belong to a sequence-parallel layer norm.
        self.weight.sequence_parallel = self.sequence_parallel
        self.bias.sequence_parallel = self.sequence_parallel

    @staticmethod
    def _normalize_shape(normalized_shape):
        """Return ``normalized_shape`` as a tuple, wrapping a bare integer."""
        if isinstance(normalized_shape, numbers.Integral):
            return (normalized_shape,)
        return tuple(normalized_shape)

    @classmethod
    def _supports_persist_kernel(cls, normalized_shape):
        """Return True if the persistent kernel handles ``normalized_shape``.

        Only a single normalized (hidden) dimension whose size appears in
        ``_PERSIST_LN_HIDDEN_SIZES`` qualifies.
        """
        return (len(normalized_shape) == 1
                and normalized_shape[0] in cls._PERSIST_LN_HIDDEN_SIZES)

    def reset_parameters(self):
        """Initialize ``weight`` to ones and ``bias`` to zeros."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Apply layer norm to ``input`` with the kernel chosen in ``__init__``."""
        if self.no_persist_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)

        output = FastLayerNormFN.apply(input, self.weight, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
        # a populated '_base' field). This would make schedule.py's
        # deallocate_output_tensor() throw an error, so a viewless tensor is
        # created to prevent this.
        return make_viewless_tensor(inp=output,
                                    requires_grad=input.requires_grad,
                                    keep_graph=True)