layer_norm.py 21.9 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
# Copyright (c) 2022, Tri Dao.
2
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
Tri Dao's avatar
Tri Dao committed
3

Tri Dao's avatar
Tri Dao committed
4
import dropout_layer_norm
5
6
7
8
import torch
from torch.nn import init


9
def maybe_align(x, alignment_in_bytes=16):
    """Return ``x`` unchanged if its data pointer is aligned, else a cloned copy.

    Only the start address of ``x`` is checked against ``alignment_in_bytes``;
    the caller is expected to have already made the last dimension of ``x`` a
    multiple of ``alignment_in_bytes``.
    """
    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
    misaligned = x.data_ptr() % alignment_in_bytes != 0
    if misaligned:
        return x.clone()
    return x


Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
23
24
25
26
27
28
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Call the fused dropout + residual-add + LayerNorm/RMSNorm forward kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    ``x0`` and ``residual`` are flattened to ``(-1, hidden_size)`` with
    ``hidden_size = gamma.numel()``; ``rowscale`` is flattened to 1-D.

    Returns:
        ``(zmat, xmat, dmask, mu, rsigma)`` from
        ``dropout_layer_norm.dropout_add_ln_fwd``, except that ``xmat`` is
        replaced by ``x0mat`` when the kernel returns None for it.
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    rowscale = rowscale.view(-1) if rowscale is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        rowscale,
        colscale,
        None,  # x0_subset (not used outside the subset variant)
        None,  # out_subset (not used outside the subset variant)
        dropout_p,
        epsilon,
        1.0,  # rowscale_const
        0,  # out_numrows
        None,  # presumably the RNG generator (use default) — TODO confirm
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0.
    # xmat may be None when the kernel does not materialize x. Falling back to
    # x0mat below is only valid if x is identical to x0, i.e. (presumably) no
    # dropout, no residual/rowscale/colscale, and the residual dtype EQUAL to the
    # input dtype — TODO confirm the exact condition against the CUDA extension.
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


Tri Dao's avatar
Tri Dao committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Call the fused dropout + residual-add + LayerNorm/RMSNorm backward kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    A ``dx`` of None signals a post-norm architecture: the forward pass did not
    return ``x = drop(x0) + residual``, so there is no extra gradient for it.
    ``x0`` has to be supplied whenever ``colscale`` is, because the colscale
    gradient needs it.

    Returns:
        ``(dx0mat, dresidualmat, dgamma, dbeta)``, with ``dcolscale`` appended as
        a fifth element when ``colscale`` is not None. ``dresidualmat`` is None
        when ``has_residual`` is False.
    """
    width = gamma.numel()
    xmat = x.view((-1, width))
    flat_shape = xmat.shape
    dzmat = dz.view(flat_shape)
    dxmat = None if dx is None else dx.view(flat_shape)
    x0mat = None if x0 is None else x0.view((-1, width))
    flat_rowscale = None if rowscale is None else rowscale.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *extra = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        flat_rowscale,
        colscale,
        None,  # x0_subset
        None,  # out_subset
        dropout_p,
        1.0,  # rowscale_const
        0,  # x0_numrows
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # The colscale gradient is the first of the trailing kernel outputs.
    return dx0mat, dresidualmat, dgamma, dbeta, extra[0]
108
109


Tri Dao's avatar
Tri Dao committed
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Call the fused dropout + add + LayerNorm/RMSNorm forward kernel, subset variant.

    Assume that arguments are contiguous and aligned to 16 bytes.

    ``x0`` and ``residual`` are flattened to ``(-1, hidden_size)`` with
    ``hidden_size = gamma.numel()``; the ``x0_subset`` / ``out_subset`` index
    tensors are flattened to 1-D. No per-row ``rowscale`` tensor is passed; the
    scalar ``rowscale_const`` is forwarded instead.

    Returns:
        ``(zmat, xmat, dmask, mu, rsigma)`` from
        ``dropout_layer_norm.dropout_add_ln_fwd``, except that ``xmat`` is
        replaced by ``x0mat`` when the kernel returns None for it.
    """
    hidden_size = gamma.numel()
    x0mat = x0.view((-1, hidden_size))
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    x0_subset = x0_subset.view(-1) if x0_subset is not None else None
    out_subset = out_subset.view(-1) if out_subset is not None else None
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0mat,
        residualmat,
        gamma,
        beta,
        None,  # rowscale tensor (unused in the subset variant)
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,  # presumably the RNG generator (use default) — TODO confirm
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask is None if dropout_p == 0.0.
    # xmat may be None when the kernel does not materialize x. Falling back to
    # x0mat below is only valid if x is identical to x0, i.e. (presumably) no
    # dropout, no residual/subsets/colscale, and the residual dtype EQUAL to the
    # input dtype — TODO confirm the exact condition against the CUDA extension.
    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma


Tri Dao's avatar
Tri Dao committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Call the fused dropout + add + LayerNorm/RMSNorm backward kernel, subset variant.

    Assume that arguments are contiguous and aligned to 16 bytes.

    A ``dx`` of None signals a post-norm architecture: the forward pass did not
    return ``x = drop(x0) + residual``. ``x0`` has to be supplied whenever
    ``colscale`` is, because the colscale gradient needs it.

    Returns:
        ``(dx0mat, dresidualmat, dgamma, dbeta)``, with ``dcolscale`` appended as
        a fifth element when ``colscale`` is not None. ``dresidualmat`` is None
        when ``has_residual`` is False.
    """
    width = gamma.numel()
    xmat = x.view((-1, width))
    # dz is reshaped by hidden size only (not to xmat.shape); presumably because
    # the output can have a different row count (out_numrows) than x.
    dzmat = dz.view(-1, width)
    dxmat = None if dx is None else dx.view(xmat.shape)
    x0mat = None if x0 is None else x0.view((-1, width))
    x0_subset = None if x0_subset is None else x0_subset.view(-1)
    out_subset = None if out_subset is None else out_subset.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *extra = dropout_layer_norm.dropout_add_ln_bwd(
        dzmat,
        dxmat,
        xmat,
        x0mat,
        dmask,
        mu,
        rsigma,
        gamma,
        None,  # rowscale tensor (unused in the subset variant)
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # The colscale gradient is the first of the trailing kernel outputs.
    return dx0mat, dresidualmat, dgamma, dbeta, extra[0]
210
211


212
def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Call the fused parallel-residual dropout + add + LayerNorm/RMSNorm forward kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    ``x0``, ``x1`` and ``residual`` are flattened to ``(-1, hidden_size)`` with
    ``hidden_size = gamma0.numel()``. Two normalized outputs are produced, one
    per ``(gamma, beta)`` pair; ``z1mat`` may be None (presumably when
    ``gamma1`` is None — TODO confirm).

    Returns:
        ``(z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma)`` from
        ``dropout_layer_norm.dropout_add_ln_parallel_residual_fwd``, except that
        ``xmat`` is replaced by ``x0mat`` when the kernel returns None for it.
    """
    hidden_size = gamma0.numel()
    x0mat = x0.view((-1, hidden_size))
    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
    residualmat = residual.view((-1, hidden_size)) if residual is not None else None
    (
        z0mat,
        z1mat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0mat,
        x1mat,
        residualmat,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,  # presumably the RNG generator (use default) — TODO confirm
        residual_in_fp32,
        is_rms_norm,
    )
    # dmask0 and dmask1 are None if dropout_p == 0.0.
    # xmat may be None when the kernel does not materialize x. Falling back to
    # x0mat below is only valid if x is identical to x0, i.e. (presumably) no
    # dropout, no x1, no residual, and the residual dtype EQUAL to the input
    # dtype — TODO confirm the exact condition against the CUDA extension.
    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma


def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Call the fused parallel-residual dropout + add + LayerNorm/RMSNorm backward kernel.

    Assume that arguments are contiguous and aligned to 16 bytes.

    A ``dx`` of None signals a post-norm architecture: the forward pass did not
    return ``x = drop(x0) + residual``. ``dz1`` may be None when there is no
    second normalized output.

    Returns:
        ``(dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1)``;
        ``dresidualmat`` is None when ``has_residual`` is False.
    """
    width = gamma0.numel()
    xmat = x.view((-1, width))
    flat_shape = xmat.shape
    dz0mat = dz0.view(flat_shape)
    dz1mat = None if dz1 is None else dz1.view(flat_shape)
    dxmat = None if dx is None else dx.view(flat_shape)
    grads = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0mat,
        dz1mat,
        dxmat,
        xmat,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # The kernel must return at least these seven gradients; anything after
    # them is ignored.
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *_ignored = grads
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1


Tri Dao's avatar
Tri Dao committed
311
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd wrapper around the fused dropout + residual-add + LayerNorm/RMSNorm kernels.

    Forward returns the normalized output ``z``. With ``prenorm=True`` it also
    returns the pre-norm sum ``x``. With ``return_dmask=True`` it also returns a
    non-differentiable uint8 dropout mask, placed last.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        def _aligned_or_none(t):
            # Optional inputs: make contiguous + 16-byte aligned only if present.
            return None if t is None else maybe_align(t.contiguous(), 16)

        x0 = maybe_align(x0.contiguous(), 16)
        residual = _aligned_or_none(residual)
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = _aligned_or_none(beta)
        rowscale = _aligned_or_none(rowscale)
        colscale = _aligned_or_none(colscale)
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        out_shape = x0.shape
        # x0 is only kept around when backward has to produce a colscale gradient.
        ctx.save_for_backward(
            xmat.view(out_shape),
            None if colscale is None else x0,
            dmask,
            gamma,
            mu,
            rsigma,
            rowscale,
            colscale,
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None

        if not return_dmask:
            if prenorm:
                return zmat.view(out_shape), xmat.view(out_shape)
            return zmat.view(out_shape)

        # With no dropout the kernel yields no mask; report an all-ones mask instead.
        if dropout_p > 0.0:
            mask = dmask.view(out_shape)
        else:
            mask = torch.ones(out_shape, dtype=torch.uint8, device=x0.device)
        ctx.mark_non_differentiable(mask)
        if prenorm:
            return zmat.view(out_shape), xmat.view(out_shape), mask
        return zmat.view(out_shape), mask

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # In prenorm mode the second forward output (x) also carries a gradient.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
        # x0 is None if colscale is None
        dx0mat, dresidualmat, dgamma, dbeta, *extra = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            ctx.dropout_p,
            ctx.has_residual,
            ctx.is_rms_norm,
        )
        grad_x0 = dx0mat.view(x.shape)
        grad_residual = None if dresidualmat is None else dresidualmat.view(x.shape)
        grad_colscale = None if colscale is None else extra[0]
        # One entry per forward input, in forward's parameter order.
        return (
            grad_x0,  # x0
            grad_residual,  # residual
            dgamma,  # gamma
            dbeta if ctx.has_beta else None,  # beta
            None,  # rowscale
            grad_colscale,  # colscale
            # dropout_p, epsilon, residual_in_fp32, prenorm, is_rms_norm, return_dmask
            *([None] * 6),
        )
414
415


416
417
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    """Autograd wrapper around the subset variant of the fused dropout + add + LayerNorm kernels.

    ``x0_subset`` / ``out_subset`` are index tensors forwarded to the kernel.
    ``out_numrows`` (forward) and ``x0_numrows`` (backward) are passed to the
    kernel separately, so the row counts of ``x0``, of the residual stream ``x``
    and of the output ``z`` are not assumed equal. The leading dimension of
    ``x`` and ``z`` is therefore left free, via ``(-1, *x0.shape[1:])``.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # Only need to save x0 if we need to compute gradient wrt colscale
        x0_saved = x0 if colscale is not None else None
        # x and z keep x0's trailing dims but their own number of rows.
        x_shape = (-1, *x0.shape[1:])
        z_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        # Number of rows of x0 (all leading dims flattened); backward needs it to size dx0.
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z = zmat.view(z_shape)
        if not return_dmask:
            # x is viewed with the row-free x_shape, as in save_for_backward and
            # the return_dmask branch below: viewing it to x0.shape would only
            # work when x happens to have exactly as many rows as x0.
            return z if not prenorm else (z, xmat.view(x_shape))
        # With no dropout the kernel yields no mask; report an all-ones mask instead.
        dmask = (
            dmask.view(x0.shape)
            if dropout_p > 0.0
            else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
        )
        ctx.mark_non_differentiable(dmask)
        return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        # assert dz.is_contiguous()
        dz = maybe_align(dz.contiguous(), 16)  # this happens!
        # In prenorm mode the second forward output (x) also carries a gradient.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
        # x0 is None if colscale is None
        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        dcolscale = rest[0] if colscale is not None else None
        # One entry per forward input, in forward's parameter order.
        return (
            dx0,  # x0
            dresidual,  # residual
            dgamma,  # gamma
            dbeta if ctx.has_beta else None,  # beta
            dcolscale,  # colscale
            None,  # x0_subset
            None,  # out_subset
            None,  # dropout_p
            None,  # epsilon
            None,  # rowscale_const
            None,  # out_numrows
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
529
530


531
532
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    """Autograd wrapper around the parallel-residual fused dropout + add + LayerNorm kernels.

    Forward returns ``(z0, z1)``, where ``z1`` is None if the kernel produced
    none. With ``prenorm=True`` the pre-norm sum ``x`` is appended. With
    ``return_dmask=True`` the two non-differentiable uint8 dropout masks are
    appended last.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        # Track each optional affine parameter separately. Backward may only
        # return a gradient for an input that was actually a tensor, and must
        # not drop the gradient of one that was.
        ctx.has_beta0 = beta0 is not None
        ctx.has_gamma1 = gamma1 is not None
        ctx.has_beta1 = beta1 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        # With no dropout (or no x1) the kernel yields no mask; use all-ones masks instead.
        dmask0 = (
            dmask0.view(x0.shape)
            if dropout_p > 0.0
            else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
        )
        dmask1 = (
            dmask1.view(x0.shape)
            if dropout_p > 0.0 and x1 is not None
            else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
        )
        ctx.mark_non_differentiable(dmask0)
        ctx.mark_non_differentiable(dmask1)
        return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        # In prenorm mode the third forward output (x) also carries a gradient.
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            ctx.dropout_p,
            ctx.has_x1,
            ctx.has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        # One entry per forward input, in forward's parameter order.
        return (
            dx0,  # x0
            dx1,  # x1
            dresidual,  # residual
            dgamma0,  # gamma0
            dbeta0 if ctx.has_beta0 else None,  # beta0
            dgamma1 if ctx.has_gamma1 else None,  # gamma1
            dbeta1 if ctx.has_beta1 else None,  # beta1
            None,  # dropout_p
            None,  # epsilon
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
655
656


Tri Dao's avatar
Tri Dao committed
657
658
659
660
def layer_norm(x, weight, bias, epsilon):
    """Plain LayerNorm of ``x`` via the fused kernel: no residual, no dropout, no scaling."""
    no_dropout = 0.0
    return DropoutAddLayerNormFn.apply(
        x,
        None,  # residual
        weight,
        bias,
        None,  # rowscale
        None,  # colscale
        no_dropout,
        epsilon,
        False,  # residual_in_fp32
    )


Tri Dao's avatar
Tri Dao committed
661
662
663
664
665
666
667
668
669
670
671
672
673
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Fused dropout + residual add + LayerNorm (never RMSNorm here).

    When ``residual`` is given, the residual stream keeps ``residual.dtype``;
    ``residual_in_fp32`` only matters when ``residual`` is None.
    ``layerscale`` is forwarded as the kernel's ``colscale``.

    Returns the normalized output. With ``prenorm=True`` the pre-norm sum is
    also returned. With ``return_dropout_mask=True`` the uint8 dropout mask is
    appended last.
    """
    use_rms_norm = False
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        use_rms_norm,
        return_dropout_mask,
    )
    return DropoutAddLayerNormFn.apply(*fn_args)
691
692


Tri Dao's avatar
Tri Dao committed
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Subset variant of the fused dropout + residual add + LayerNorm (never RMSNorm here).

    When ``residual`` is given, the residual stream keeps ``residual.dtype``;
    ``residual_in_fp32`` only matters when ``residual`` is None.
    ``layerscale`` is forwarded as the kernel's ``colscale``. ``x0_subset``,
    ``out_subset``, ``rowscale_const`` and ``out_numrows`` are passed straight
    to the subset kernel.
    """
    use_rms_norm = False
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        use_rms_norm,
        return_dropout_mask,
    )
    return DropoutAddLayerNormSubsetFn.apply(*fn_args)


def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Parallel-residual fused dropout + add + LayerNorm with two affine heads (never RMSNorm here).

    When ``residual`` is given, the residual stream keeps ``residual.dtype``;
    ``residual_in_fp32`` only matters when ``residual`` is None.
    """
    use_rms_norm = False
    fn_args = (
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        use_rms_norm,
        return_dropout_mask,
    )
    return DropoutAddLayerNormParallelResidualFn.apply(*fn_args)


765
class DropoutAddLayerNorm(torch.nn.Module):
    """``nn.Module`` wrapper around ``dropout_add_layer_norm`` with learnable weight and bias.

    Dropout probability ``p`` is used only in training mode; in eval mode the
    kernel is called with a dropout probability of ``0.0``.
    """

    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        tensor_kwargs = dict(device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **tensor_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **tensor_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        # Identity affine transform: weight = 1, bias = 0.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        drop_prob = self.p if self.training else 0.0
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            drop_prob,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )