# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
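#
# Fused dropout + residual-add + LayerNorm. The heavy lifting is done by the
# `dropout_layer_norm` CUDA extension (dropout_add_ln_fwd / dropout_add_ln_bwd),
# which must be built and installed separately.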

import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                                    residual_in_fp32):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, x1mat, gamma, beta, rowscale, colscale, dropout_p, epsilon, None, residual_in_fp32
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and x1 is None and the residual dtype matches the
    # input dtype (i.e. when x would just equal x0, so the kernel does not materialize it)
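    # mu and rsigma are the per-row mean and reciprocal standard deviation computed by the
    # LayerNorm kernel; they are saved and reused in the backward pass.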
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + x1 was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    colscale = colscale.view(-1) if colscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dx1mat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p,
        has_residual
    )
    # dx1mat is None if not has_residual
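    # rest contains dcolscale only when colscale was provided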
    if colscale is None:
        return dx0mat, dx1mat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dx1mat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32,
                prenorm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous()
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, x1, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = x1 is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # dz is not guaranteed to be contiguous (this does happen in practice),
        # so make it contiguous rather than asserting.
        dz = dz.contiguous()
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dcolscale = rest[0] if colscale is not None else None
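        # The None entries correspond to the rowscale, dropout_p, epsilon, residual_in_fp32,
        # prenorm, and return_dmask inputs of forward(), which do not receive gradients.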
        return dx0, dx1, dgamma, dbeta, None, dcolscale, None, None, None, None, None


def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None,
                           prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if x1 is None.
    Otherwise residual dtype is x1.dtype.
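    If prenorm is True, the residual x = dropout(x0) + x1 (before normalization) is returned
    together with the normalized output. If return_dropout_mask is True, the dropout mask is
    also returned as the last output (all ones when dropout_p == 0).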
    """
    return DropoutAddLayerNormFn.apply(
        x0, x1, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        return_dropout_mask
    )


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, x1=None):
        return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
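

# Illustrative usage sketch (not part of the original module): it exercises the fused
# dropout + residual-add + LayerNorm in the post-norm and pre-norm configurations.
# It assumes a CUDA device and the compiled `dropout_layer_norm` extension are available;
# the shapes, dtype, and dropout probability below are arbitrary example values.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, seqlen, hidden_size = 2, 128, 1024
    x0 = torch.randn(batch, seqlen, hidden_size, device='cuda', dtype=torch.float16)
    x1 = torch.randn_like(x0)  # residual branch

    # Post-norm: returns only the normalized output.
    ln_post = DropoutAddLayerNorm(hidden_size, prenorm=False, p=0.1,
                                  device='cuda', dtype=torch.float16)
    out = ln_post(x0, x1)

    # Pre-norm: additionally returns the residual x = dropout(x0) + x1.
    ln_pre = DropoutAddLayerNorm(hidden_size, prenorm=True, p=0.1,
                                 device='cuda', dtype=torch.float16)
    out_pre, residual = ln_pre(x0, x1)
    print(out.shape, out_pre.shape, residual.shape)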