# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
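#
# As a summary of the fused forward below, each row computes roughly:
#     x = dropout(x * rowscale) [+ dropout(x1)] [+ residual]
#     out = norm(x) * weight + bias    (optionally a second out1 with weight1 / bias1)
# With prenorm=True the pre-normalization sum x is also returned as the new residual stream.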

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
            dtype
        )
        return (out, out1) if not prenorm else (out, out1, x)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
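    # x1 (if present) uses the second half of SEEDS and DROPOUT_MASK (rows M..2*M-1).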
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
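    # One dropout seed per row; when x1 is given, rows M..2*M-1 seed the x1 dropout.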
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    if row_start >= M:
        return
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
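    # Partial weight/bias gradients are kept in registers and accumulated over this
    # program's rows (see the note at the top of the file), then written out once per program.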
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
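        # LayerNorm: dx = (wdy - xhat * mean(xhat * wdy) - mean(wdy)) * rstd
        # RMSNorm:   dx = (wdy - xhat * mean(xhat * wdy)) * rstd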
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
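    # Each program handles a contiguous block of rows; weight/bias grads are accumulated
    # per program into (sm_count, N) buffers and reduced over dim 0 on the host below.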
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
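        # args holds (dy1 if weight1 was given) followed by (dresidual if prenorm=True).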
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
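

if __name__ == "__main__":
    # Minimal sanity check (illustrative only, not part of the library API): compares the
    # fused Triton kernels against the pure-PyTorch references above. Assumes a CUDA device
    # with Triton available; shapes and dtypes here are arbitrary.
    torch.manual_seed(0)
    batch, seqlen, hidden = 2, 128, 1024
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    w = torch.randn(hidden, device="cuda", dtype=torch.float32)
    b = torch.randn(hidden, device="cuda", dtype=torch.float32)
    out, res_out = layer_norm_fn(x, w, b, residual=residual, prenorm=True)
    out_ref, res_ref = layer_norm_ref(x, w, b, residual=residual, prenorm=True, upcast=True)
    print("layer_norm max diff:", (out - out_ref).abs().max().item())
    print("residual max diff:", (res_out - res_ref).abs().max().item())
    out_rms = rms_norm_fn(x, w, None)
    out_rms_ref = rms_norm_ref(x, w, None, upcast=True)
    print("rms_norm max diff:", (out_rms - out_rms_ref).abs().max().item())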