fused_layer_norm.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This code is copied fron NVIDIA apex:
      https://github.com/NVIDIA/apex
18
   with some changes. """
19
20

import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

from apex.contrib.layer_norm.layer_norm import FastLayerNormFN


global fused_mix_prec_layer_norm_cuda
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):
  """Autograd wrapper around the fused mixed-precision layer norm CUDA kernel
  (affine version, i.e. with learnable weight and bias)."""

  @staticmethod
  def forward(ctx, input, weight, bias, normalized_shape, eps):

    ctx.normalized_shape = normalized_shape
    ctx.eps = eps
    input_ = input.contiguous()
    weight_ = weight.contiguous()
    bias_ = bias.contiguous()
    output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
    # Save the statistics so the backward kernel does not need to recompute them.
    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

    return output


  @staticmethod
  def backward(ctx, grad_output):

    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    grad_input, grad_weight, grad_bias \
      = fused_mix_prec_layer_norm_cuda.backward_affine(
        grad_output.contiguous(), mean, invvar,
        input_, ctx.normalized_shape,
        weight_, bias_, ctx.eps)

    # The two trailing None values correspond to normalized_shape and eps,
    # which are not tensors and need no gradient.
    return grad_input, grad_weight, grad_bias, None, None
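

# Reference (non-fused) semantics of the affine fused layer norm above. This is
# an illustrative sketch only and is not called anywhere in this module; the
# helper name `_reference_layer_norm_affine` is not part of the original code,
# and it assumes `normalized_shape` is a tuple or torch.Size.
def _reference_layer_norm_affine(x, weight, bias, normalized_shape, eps):
  # Normalize over the trailing `normalized_shape` dimensions, then apply the
  # elementwise affine transform: y = (x - mean) / sqrt(var + eps) * w + b.
  dims = tuple(range(-len(normalized_shape), 0))
  mean = x.mean(dim=dims, keepdim=True)
  var = x.var(dim=dims, unbiased=False, keepdim=True)
  return (x - mean) * torch.rsqrt(var + eps) * weight + bias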



class MixedFusedLayerNorm(torch.nn.Module):
  """Layer norm module that dispatches either to apex's persistent
  FastLayerNormFN kernel or to the non-persistent fused kernel above."""

  def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
    super(MixedFusedLayerNorm, self).__init__()

    # Lazily import the CUDA extension so that importing this module does not
    # by itself require the extension to be built.
    global fused_mix_prec_layer_norm_cuda
    fused_mix_prec_layer_norm_cuda = importlib.import_module(
      "fused_mix_prec_layer_norm_cuda")

    # List of hidden sizes supported by the persistent layer norm kernel.
    # If the hidden size is not supported, fall back to the non-persistent
    # kernel.
    persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
        5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
        24576, 25600, 30720, 32768, 40960, 49152, 65536]
    if normalized_shape not in persist_ln_hidden_sizes:
        no_persist_layer_norm = True

    if isinstance(normalized_shape, numbers.Integral):
        normalized_shape = (normalized_shape,)
    self.normalized_shape = torch.Size(normalized_shape)
    self.eps = eps
    self.weight = Parameter(torch.Tensor(*normalized_shape))
    self.bias = Parameter(torch.Tensor(*normalized_shape))
    self.reset_parameters()
    self.no_persist_layer_norm = no_persist_layer_norm


  def reset_parameters(self):

    init.ones_(self.weight)
    init.zeros_(self.bias)


  def forward(self, input):

    # Use the persistent kernel only when the hidden size supports it;
    # otherwise fall back to the generic fused affine kernel.
    if self.no_persist_layer_norm:
        return FusedLayerNormAffineFunction.apply(
          input, self.weight, self.bias, self.normalized_shape, self.eps)
    else:
        return FastLayerNormFN.apply(
          input, self.weight, self.bias, self.eps)
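

# Minimal usage sketch (illustrative only; assumes a CUDA build with apex and
# the `fused_mix_prec_layer_norm_cuda` extension available):
#
#   ln = MixedFusedLayerNorm(1024).cuda().half()
#   x = torch.randn(8, 512, 1024, dtype=torch.half, device='cuda')
#   y = ln(x)  # same shape as x, normalized over the last (hidden) dimension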