# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    return out if not prenorm else (out, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        weight = weight.float()
        bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        residual = residual.float() if residual is not None else residual
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    out = out.to(dtype)
    return out if not prenorm else (out, x)
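

# Illustrative usage sketch: exercises the pure-PyTorch reference implementations above.
# The shapes below are made up for the example; the references run fine on CPU, unlike the
# fused Triton kernels further down.
def _example_reference_norms():
    batch, seqlen, hidden = 2, 4, 8
    x = torch.randn(batch, seqlen, hidden)
    residual = torch.randn(batch, seqlen, hidden)
    weight = torch.ones(hidden)
    bias = torch.zeros(hidden)
    # Post-norm style: only the normalized output is returned.
    out_ln = layer_norm_ref(x, weight, bias, residual=residual, eps=1e-6)
    # Pre-norm style: also return x + residual so the caller can carry the residual stream forward.
    out_rms, residual_out = rms_norm_ref(x, weight, None, residual=residual, prenorm=True, upcast=True)
    return out_ln.shape, out_rms.shape, residual_out.shape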


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
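    # One independent seed per row; the backward kernel reloads these same seeds to regenerate
    # the dropout mask instead of storing it.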
    if dropout_p > 0.0:
        seeds = torch.randint(2**32, (M,), device=x.device, dtype=torch.int64)
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty_like(x, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
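    # e.g. fp16/bf16 inputs (2 bytes per element) give MAX_FUSED_SIZE = 32768 columns, fp32 gives 16384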
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None iff residual is None, residual_dtype == input dtype, dropout_p == 0.0, and rowscale is None
    return y, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    if row_start >= M:
        return
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
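        # Recompute the forward pass's dropout mask from the per-row seed (tl.rand is
        # deterministic for a given seed and offsets), so the mask never has to be stored.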
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None)
        else None
    )
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
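    # One program per SM; each program accumulates partial dw/db over its chunk of rows in
    # registers (see the note at the top of the file), and the partial sums are reduced on the host.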
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
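

# Illustrative sketch, not used by the kernels: the two-stage weight-gradient reduction that
# _layer_norm_bwd_kernel and _layer_norm_bwd perform, written in plain PyTorch. Each "program"
# sums dy * xhat over its own chunk of rows into a float32 buffer, and the per-program partial
# sums are then reduced on the host, exactly like _dw.sum(0) above.
def _split_reduction_sketch(dy, xhat, num_programs=4):
    M, N = dy.shape
    rows_per_program = math.ceil(M / num_programs)
    partial = torch.zeros(num_programs, N, dtype=torch.float32)
    for block_id in range(num_programs):
        start = block_id * rows_per_program
        end = min(start + rows_per_program, M)
        partial[block_id] = (dy[start:end].float() * xhat[start:end].float()).sum(dim=0)
    return partial.sum(0)  # equivalent to (dy * xhat).sum(0) up to float32 rounding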


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out, seeds, dropout_mask = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(residual_out, weight, bias, rowscale, seeds, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        if not return_dropout_mask:
            return y if not prenorm else (y, residual_out)
        else:
            return (y, dropout_mask) if not prenorm else (y, residual_out, dropout_mask)

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )
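

# Illustrative usage sketch, assuming a CUDA device is available (the Triton kernels require one):
# a fused dropout + residual-add + LayerNorm step in pre-norm style, checked against the PyTorch
# reference by reusing the dropout mask returned by the kernel. Shapes and dtypes are made up.
def _example_layer_norm_fn():
    batch, seqlen, hidden = 2, 512, 1024
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    weight = torch.randn(hidden, device="cuda", dtype=torch.float16)
    bias = torch.randn(hidden, device="cuda", dtype=torch.float16)
    out, residual_out, dropout_mask = layer_norm_fn(
        x,
        weight,
        bias,
        residual=residual,
        eps=1e-6,
        dropout_p=0.1,
        prenorm=True,
        return_dropout_mask=True,
    )
    out_ref, residual_ref = layer_norm_ref(
        x,
        weight,
        bias,
        residual=residual,
        eps=1e-6,
        dropout_p=0.1,
        dropout_mask=dropout_mask,
        prenorm=True,
        upcast=True,
    )
    return (out - out_ref).abs().max(), (residual_out - residual_ref).abs().max()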


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.dropout_p = dropout_p
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.dropout_p if self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
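

# Illustrative usage sketch, assuming a CUDA device is available: the RMSNorm module in a
# pre-norm residual block, where the module also returns the updated residual stream (kept in
# fp32 via residual_in_fp32). The loop is a stand-in for real transformer blocks.
def _example_rmsnorm_module():
    norm = RMSNorm(1024, eps=1e-5, dropout_p=0.1, device="cuda", dtype=torch.float16)
    x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
    residual = None
    for _ in range(2):
        out, residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)
        x = out  # stand-in for the block's actual computation
    return out.shape, residual.dtype  # residual stays in float32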


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        # LayerNormLinearFn does not expose dropout, so the returned seeds and dropout mask are discarded.
        y, mean, rstd, residual_out, _, _ = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
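

# Illustrative usage sketch, assuming a CUDA device is available: fusing the normalization with
# the following linear projection. As noted in LayerNormLinearFn, the normalized activations are
# not saved; the backward pass recomputes them (recompute_output=True) to save memory.
def _example_layer_norm_linear_fn():
    hidden, out_features = 1024, 4096
    x = torch.randn(2, 512, hidden, device="cuda", dtype=torch.float16, requires_grad=True)
    norm_weight = torch.ones(hidden, device="cuda", requires_grad=True)
    linear_weight = torch.randn(
        out_features, hidden, device="cuda", dtype=torch.float16, requires_grad=True
    )
    out = layer_norm_linear_fn(x, norm_weight, None, linear_weight, None, is_rms_norm=True)
    out.sum().backward()
    return x.grad.shape, norm_weight.grad.shape, linear_weight.grad.shape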