# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

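# Fused CUDA extension providing the dropout + residual-add + LayerNorm/RMSNorm kernels
# used below; it is a compiled extension that must be built and installed separately.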
import dropout_layer_norm


def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                    epsilon, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
        1.0, 0, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                     dropout_p, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(xmat.shape)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
        dropout_p, 1.0, 0, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                           out_subset, dropout_p, epsilon, rowscale_const,
                                           out_numrows, residual_in_fp32=False, is_rms_norm=False):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
    )
    # dmask is None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                            x0_subset, out_subset, dropout_p, rowscale_const,
                                            x0_numrows, has_residual, is_rms_norm=False):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    xmat = x.view((-1, hidden_size))
    dzmat = dz.view(-1, hidden_size)
    dxmat = dx.view(xmat.shape) if dx is not None else None
    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    if colscale is not None:
        assert x0 is not None, 'x0 is required to compute the gradient of colscale'
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
        dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    else:
        dcolscale = rest[0]
        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale


def _dropout_add_layer_norm_parallel_residual_forward(
    x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
    epsilon, residual_in_fp32=False, is_rms_norm=False
):
    """ Assume that arguments are contiguous
    """
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
        None, residual_in_fp32, is_rms_norm
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0
    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
    dropout_p, has_x1, has_residual, is_rms_norm=False
):
    """ Assume that arguments are contiguous
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    xmat = x.view((-1, hidden_size))
    dz0mat = dz0.view(xmat.shape)
    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
    dxmat = dx.view(xmat.shape) if dx is not None else None
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
        dropout_p, has_x1, has_residual, is_rms_norm
    )
    # dresidualmat is None if not has_residual
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


class DropoutAddLayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        rowscale = rowscale.contiguous() if rowscale is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (zmat.view(x0.shape) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape)))
        else:
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((zmat.view(x0.shape), dmask) if not prenorm
                    else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
            ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
                None, None, None, None, None)


class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                rowscale_const, out_numrows, residual_in_fp32=False,
                prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        residual = residual.contiguous() if residual is not None else None
        gamma = gamma.contiguous()
        beta = beta.contiguous() if beta is not None else None
        colscale = colscale.contiguous() if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
            rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
                              x0_subset, out_subset)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            return (zmat.view(z_shape) if not prenorm
                    else (zmat.view(z_shape), xmat.view(x_shape)))
        else:
            z = zmat.view(z_shape)
            dmask = (dmask.view(x0.shape) if dropout_p > 0.
                     else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask)
            return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = dz.contiguous()  # this happens!
        dx = args[0].contiguous() if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
            ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
                None, None, None, None, None, None, None, None)


class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
        x0 = x0.contiguous()
        x1 = x1.contiguous() if x1 is not None else None
        residual = residual.contiguous() if residual is not None else None
        gamma0 = gamma0.contiguous()
        beta0 = beta0.contiguous() if beta0 is not None else None
        gamma1 = gamma1.contiguous() if gamma1 is not None else None
        beta1 = beta1.contiguous() if beta1 is not None else None
        z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
            x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
            residual_in_fp32, is_rms_norm
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
                      else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = dz0.contiguous()  # this happens!
        dz1 = dz1.contiguous() if dz1 is not None else None
        dx = args[0].contiguous() if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
            dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
            has_residual, ctx.is_rms_norm
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
                dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)


def layer_norm(x, weight, bias, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
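
# Example usage of layer_norm (a minimal sketch; the hidden size, batch shape, dtype and
# CUDA placement below are illustrative assumptions, not requirements enforced here):
#
#     hidden = 1024
#     x = torch.randn(8, 512, hidden, device='cuda', dtype=torch.float16)
#     weight = torch.ones(hidden, device='cuda', dtype=torch.float16)
#     bias = torch.zeros(hidden, device='cuda', dtype=torch.float16)
#     out = layer_norm(x, weight, bias, 1e-5)  # same shape as x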


def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                           layerscale=None, prenorm=False, residual_in_fp32=False,
                           return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        False, return_dropout_mask
    )
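
# Example usage of dropout_add_layer_norm in a pre-norm residual block (a minimal sketch;
# tensor shapes, dtype and the dropout probability are illustrative assumptions):
#
#     out, new_residual = dropout_add_layer_norm(
#         x0, residual, weight, bias, 0.1, 1e-5,
#         prenorm=True, residual_in_fp32=True,
#     )
#     # With prenorm=True the call also returns new_residual = dropout(x0) + residual,
#     # which feeds the next block; with prenorm=False only the normalized output is returned.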


def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                  x0_subset=None, out_subset=None, rowscale_const=1.0,
                                  out_numrows=0, prenorm=False, residual_in_fp32=False,
                                  return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
    )
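
# Example usage of dropout_add_layer_norm_subset (a minimal sketch; x0_subset, out_subset
# and num_out_rows below are hypothetical caller-supplied values, and the exact semantics
# of the subset indices are defined by the fused kernel):
#
#     out = dropout_add_layer_norm_subset(
#         x0, residual, weight, bias, 0.1, 1e-5,
#         x0_subset=x0_subset, out_subset=out_subset,
#         rowscale_const=1.0, out_numrows=num_out_rows,
#     )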


def dropout_add_layer_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
    residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
        False, return_dropout_mask
    )
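
# Example usage of dropout_add_layer_norm_parallel_residual for a block with two parallel
# sub-layer outputs sharing one residual stream (a minimal sketch; shapes and the dropout
# probability are illustrative assumptions):
#
#     out0, out1, new_residual = dropout_add_layer_norm_parallel_residual(
#         x0, x1, residual, weight0, bias0, weight1, bias1, 0.1, 1e-5,
#         prenorm=True, residual_in_fp32=True,
#     )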


class DropoutAddLayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
                                      self.p if self.training else 0.0, self.epsilon,
                                      prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
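
# Example usage of the DropoutAddLayerNorm module (a minimal sketch; size, dtype and device
# are illustrative assumptions):
#
#     norm = DropoutAddLayerNorm(1024, prenorm=False, p=0.1, device='cuda', dtype=torch.float16)
#     out = norm(hidden_states, residual)  # hidden_states, residual: (..., 1024)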