# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale,
                              colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
                None, None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
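# Example usage (an illustrative sketch, not from the original file; it assumes a CUDA device
# and the compiled `dropout_layer_norm` extension, and the shapes below are arbitrary):
#
#   w = torch.ones(1024, device='cuda', dtype=torch.float16)
#   b = torch.zeros(1024, device='cuda', dtype=torch.float16)
#   x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#   out = layer_norm(x, w, b, 1e-5)  # same shape as x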


def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
        prenorm, False, return_dropout_mask
    )
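# Example usage (an illustrative sketch, not from the original file; tensors are assumed to
# live on a CUDA device). With prenorm=True the function also returns the updated residual
# stream x = dropout(x0) + residual, so the caller can feed it to the next block:
#
#   z, x = dropout_add_layer_norm(x0, residual, weight, bias, 0.1, 1e-5,
#                                 prenorm=True, residual_in_fp32=True)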


def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )
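# Example call pattern (a sketch only, not from the original file; the index convention for
# x0_subset / out_subset and the meaning of out_numrows are defined by the compiled
# dropout_add_ln kernels, so check those before relying on this):
#
#   z = dropout_add_layer_norm_subset(
#       x0, residual, weight, bias, 0.1, 1e-5,
#       x0_subset=x0_subset, out_subset=out_subset,
#       rowscale_const=1.0, out_numrows=out_numrows,
#   )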


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
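

# Minimal smoke test (an illustrative sketch, not part of the original module). It assumes a
# CUDA device and the compiled `dropout_layer_norm` extension; the shapes, dtype, and dropout
# probability below are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_size = 1024
    # Fused dropout + residual-add + LayerNorm; dropout is active only in training mode.
    module = DropoutAddLayerNorm(hidden_size, p=0.1, device='cuda', dtype=torch.float16)
    x0 = torch.randn(2, 128, hidden_size, device='cuda', dtype=torch.float16)
    residual = torch.randn_like(x0)
    out = module(x0, residual)
    print(out.shape, out.dtype)  # same shape as x0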