fused_layer_norm.py 3.23 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""This code is copied from NVIDIA apex:
      https://github.com/NVIDIA/apex
   with some changes. """
6
7

import numbers
8
import torch
9
10
11
12
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

13
from megatron.core.utils import make_viewless_tensor
Lawrence McAfee's avatar
fixed.  
Lawrence McAfee committed
14

15
16
17
18
19
# The persistent layer norm kernel from apex's contrib package is optional.
# If it cannot be imported, HAVE_PERSIST_LAYER_NORM is False and
# MixedFusedLayerNorm falls back to the non-persistent fused kernel.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    # ImportError also covers ModuleNotFoundError and failures to load the
    # compiled extension. Unlike a bare `except:`, it does not swallow
    # KeyboardInterrupt/SystemExit or unrelated bugs.
    HAVE_PERSIST_LAYER_NORM = False
Sangkug Lym's avatar
Sangkug Lym committed
20

21
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
22

23

24
25
# Placeholder for apex's compiled `fused_layer_norm_cuda` extension module.
# It is assigned via importlib in MixedFusedLayerNorm.__init__.
# (A `global` statement at module scope is a no-op, so none is needed here.)
fused_layer_norm_cuda = None
26
27
28


class MixedFusedLayerNorm(torch.nn.Module):
    """Affine layer norm module backed by NVIDIA apex fused kernels.

    ``forward`` uses apex's persistent ``FastLayerNormFN`` kernel when all of
    the following hold:

    * it was requested (``no_persist_layer_norm=False``);
    * it could be imported (``HAVE_PERSIST_LAYER_NORM``);
    * it supports the hidden size.

    Otherwise ``forward`` uses apex's ``FusedLayerNormAffineFunction``.

    Args:
        normalized_shape: int or sequence of ints giving the normalized shape.
            An int ``h`` is treated as ``(h,)``.
        eps: epsilon forwarded to the apex kernels.
        no_persist_layer_norm: if True, never use the persistent kernel. It is
            forced to True when the persistent kernel is unavailable or does
            not support ``normalized_shape``.
        sequence_parallel: stored on the module and set as the
            ``sequence_parallel`` attribute of the weight and bias parameters.
        apply_layernorm_1p: if True, the weight parameter holds an offset from
            1. It is initialized to zeros, and 1 is added to it in ``forward``.
    """

    # Hidden sizes supported by the persistent layer norm kernel. Any other
    # size falls back to the non-persistent kernel.
    _PERSIST_LN_HIDDEN_SIZES = frozenset((
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
        12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720,
        32768, 40960, 49152, 65536,
    ))

    @classmethod
    def _supports_persist_hidden_size(cls, normalized_shape):
        """Return True if the persistent kernel supports this (normalized) shape.

        Only a single normalized dimension whose size appears in
        ``_PERSIST_LN_HIDDEN_SIZES`` is supported.
        """
        return (len(normalized_shape) == 1
                and normalized_shape[0] in cls._PERSIST_LN_HIDDEN_SIZES)

    def __init__(self, normalized_shape, eps=1e-5,
                 no_persist_layer_norm=True,
                 sequence_parallel=False,
                 apply_layernorm_1p=False):
        super().__init__()

        self.apply_layernorm_1p = apply_layernorm_1p

        # Load apex's compiled extension into the module-level global.
        # import_module raises immediately if the extension is not installed.
        global fused_layer_norm_cuda
        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps

        # Check support on the normalized torch.Size, so that tuple or
        # torch.Size arguments such as (4096,) are recognized, not only
        # plain ints. Fall back to the non-persistent kernel otherwise.
        if not (HAVE_PERSIST_LAYER_NORM
                and self._supports_persist_hidden_size(self.normalized_shape)):
            no_persist_layer_norm = True

        # Uninitialized storage; filled in by reset_parameters().
        self.weight = Parameter(torch.Tensor(*self.normalized_shape))
        self.bias = Parameter(torch.Tensor(*self.normalized_shape))
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # Set the sequence-parallelism flag on the weight and bias parameters.
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):
        """Initialize the bias to zeros and the weight to ones.

        With ``apply_layernorm_1p`` the weight is zeroed instead, because
        ``forward`` adds 1 to it.
        """
        if self.apply_layernorm_1p:
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        """Apply layer norm to ``input`` with the selected apex kernel."""
        # With apply_layernorm_1p the stored weight is an offset from 1.
        weight = self.weight + 1 if self.apply_layernorm_1p else self.weight

        if self.no_persist_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, weight, self.bias, self.normalized_shape, self.eps)

        output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

        # Apex's fast layer norm function outputs a 'view' tensor (i.e., one
        # with a populated '_base' field). That makes schedule.py's
        # deallocate_output_tensor() raise an error, so a viewless tensor is
        # created to prevent this.
        output = make_viewless_tensor(inp=output,
                                      requires_grad=input.requires_grad,
                                      keep_graph=True)

        return output