# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

import dropout_layer_norm


def maybe_align(x, alignment_in_bytes=16):
    """Assume that x already has last dim divisible by alignment_in_bytes
    """
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()


def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
    epsilon, residual_in_fp32=False, is_rms_norm=False
):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    """
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
        None, residual_in_fp32, is_rms_norm
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
    dropout_p, has_x1, has_residual, is_rms_norm=False
):
    """ Assume that arguments are contiguous and aligned to 16 bytes
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
        dropout_p, has_x1, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
                None, None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
            has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
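# Example usage (illustrative sketch; shapes, dtypes, and the CUDA device are assumptions,
# since the fused dropout_layer_norm extension only runs on GPU):
#   x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#   w = torch.ones(1024, device='cuda', dtype=torch.float16)
#   b = torch.zeros(1024, device='cuda', dtype=torch.float16)
#   out = layer_norm(x, w, b, 1e-5)  # same result as torch.nn.functional.layer_norm(x, (1024,), w, b, eps=1e-5)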


def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        False, return_dropout_mask
    )
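# Example usage (illustrative sketch; tensor names and shapes are assumptions). The fused op
# computes roughly LayerNorm(dropout(x0) + residual) in a single kernel, with optional per-row
# (rowscale) and per-column (layerscale) scaling of x0 folded in. With prenorm=True it also
# returns the pre-normalization sum so it can be passed as the residual of the next block:
#   x0 = torch.randn(4, 256, 1024, device='cuda', dtype=torch.float16)
#   res = torch.randn_like(x0)
#   w = torch.ones(1024, device='cuda', dtype=torch.float16)
#   b = torch.zeros(1024, device='cuda', dtype=torch.float16)
#   out, new_res = dropout_add_layer_norm(x0, res, w, b, 0.1, 1e-5, prenorm=True)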


def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )
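# Example usage (illustrative sketch). With x0_subset=None and out_subset=None this behaves like
# dropout_add_layer_norm with an optional layerscale and a scalar rowscale_const; when index
# tensors are supplied they restrict which rows of the flattened batch receive x0 and which rows
# are normalized and returned, with out_numrows giving the number of returned rows:
#   out = dropout_add_layer_norm_subset(x0, res, w, b, 0.1, 1e-5, layerscale=None,
#                                       x0_subset=None, out_subset=None)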


def dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
    residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
        False, return_dropout_mask
    )
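# Example usage (illustrative sketch; tensor names are assumptions). This variant targets
# GPT-J/GPT-NeoX-style parallel blocks: x0 and x1 are two branch outputs added to one residual
# stream, weight0/bias0 and weight1/bias1 are two independent LayerNorm parameter sets, and the
# second normalized output is None when weight1 is None:
#   out0, out1, new_res = dropout_add_layer_norm_parallel_residual(
#       x0, x1, res, w0, b0, w1, b1, 0.1, 1e-5, prenorm=True)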


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.eps,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
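# Example usage of the module wrapper (illustrative sketch; hidden size, dtype, and shapes are
# assumptions):
#   norm = DropoutAddLayerNorm(1024, prenorm=False, p=0.1, eps=1e-5).to('cuda', torch.float16)
#   hidden = torch.randn(2, 128, 1024, device='cuda', dtype=torch.float16)
#   residual = torch.randn_like(hidden)
#   out = norm(hidden, residual)  # fused dropout(hidden) + residual followed by LayerNorm
#   # With prenorm=True the forward call would instead return (out, pre_norm_sum).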