poly_norm.py
import math
import operator

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count
from liger_kernel.ops.utils import set_large_grf_mode
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt


@triton.jit
def _poly_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,  # weight: [3] for [w0, w1, w2]
    B_ptr,  # bias: scalar
    RSTD_ptr,  # cache rstd for backward: shape (n_rows, 3)
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Reference:
    1. https://github.com/BryceZhuo/PolyCom/
    2. https://arxiv.org/pdf/2411.03884

    Cache rstd values for backward pass
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Load pointers
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Load input row
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

    # Load weights and bias
    w0 = tl.load(W_ptr + 0)
    w1 = tl.load(W_ptr + 1)
    w2 = tl.load(W_ptr + 2)
    b = tl.load(B_ptr)

    # Compute x³, x², x
    X_pow3 = X_row * X_row * X_row
    X_pow2 = X_row * X_row
    X_pow1 = X_row

    # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
    mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
    rstd_3 = rsqrt(mean_square_3 + eps)
    norm_x3 = X_pow3 * rstd_3

    # Compute norm(x²)
    mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
    rstd_2 = rsqrt(mean_square_2 + eps)
    norm_x2 = X_pow2 * rstd_2

    # Compute norm(x)
    mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
    rstd_1 = rsqrt(mean_square_1 + eps)
    norm_x1 = X_pow1 * rstd_1

    # Cache rstd values for backward
    tl.store(RSTD_ptr + 0, rstd_3)
    tl.store(RSTD_ptr + 1, rstd_2)
    tl.store(RSTD_ptr + 2, rstd_1)

    # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b

    # Store output
    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
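

# Illustrative pure-PyTorch reference for the PolyNorm forward formula above.
# This is a sketch for verification only and not part of the liger_kernel API;
# the helper name `_poly_norm_reference` is an assumption.
def _poly_norm_reference(x, w, b, eps=1e-6):
    # norm(u) = u / sqrt(mean(u^2) + eps), taken over the last dimension
    def _norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    # y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b
    return w[0] * _norm(x * x * x) + w[1] * _norm(x * x) + w[2] * _norm(x) + b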


@triton.jit
def _poly_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,  # shape: (n_programs, 3)
    dW_row_stride,
    dB_ptr,  # shape: (n_programs,)
    n_rows,
    n_cols,
    rows_per_program: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    PolyNorm Backward Kernel Gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]

    where:
        - D_p = RMS(x^p) = 1/rstd_p
        - S_p = sum(grad * x^p) over the row
        - d = n_cols
        - p ∈ {3, 2, 1}
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Initialize accumulators for weight and bias gradients (scalars)
    dW0_acc = 0.0
    dW1_acc = 0.0
    dW2_acc = 0.0
    dB_acc = 0.0

    # Load weights
    w0 = tl.load(W_ptr + 0).to(tl.float32)
    w1 = tl.load(W_ptr + 1).to(tl.float32)
    w2 = tl.load(W_ptr + 2).to(tl.float32)

    for row_idx in range(row_start, row_end):
        dy_base = dY_ptr + row_idx * dY_row_stride
        x_base = X_ptr + row_idx * X_row_stride
        dx_base = dX_ptr + row_idx * dX_row_stride
        rstd_base = RSTD_ptr + row_idx * RSTD_row_stride

        dY_row = tl.load(dy_base + col_offsets, mask=mask, other=0.0).to(tl.float32)
        X_row = tl.load(x_base + col_offsets, mask=mask, other=0.0).to(tl.float32)

        # Load cached rstd values
        rstd_3 = tl.load(rstd_base + 0).to(tl.float32)
        rstd_2 = tl.load(rstd_base + 1).to(tl.float32)
        rstd_1 = tl.load(rstd_base + 2).to(tl.float32)

        # Compute powers
        X_pow3 = X_row * X_row * X_row
        X_pow2 = X_row * X_row
        X_pow1 = X_row

        # Accumulate bias gradient: dB = sum(dY)
        dB_acc += tl.sum(dY_row, axis=0)

        # Compute gradient w.r.t. input using closed-form formula
        # For p=3: ∂L/∂x from w0 * norm(x³)
        S_3 = tl.sum(dY_row * X_pow3, axis=0)  # scalar
        grad_x_3 = w0 * (
            3.0 * X_pow2 * rstd_3 * dY_row
            - (3.0 / n_cols) * X_pow3 * X_pow2 * (rstd_3 * rstd_3 * rstd_3) * S_3
        )

        # For p=2: ∂L/∂x from w1 * norm(x²)
        S_2 = tl.sum(dY_row * X_pow2, axis=0)  # scalar
        grad_x_2 = w1 * (
            2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_pow3 * (rstd_2 * rstd_2 * rstd_2) * S_2
        )

        # For p=1: ∂L/∂x from w2 * norm(x)
        S_1 = tl.sum(dY_row * X_pow1, axis=0)  # scalar
        grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)

        # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
        dW0_acc += rstd_3 * S_3
        dW1_acc += rstd_2 * S_2
        dW2_acc += rstd_1 * S_1

        # Total gradient
        dX_row = grad_x_3 + grad_x_2 + grad_x_1

        # Store gradient
        tl.store(dx_base + col_offsets, dX_row, mask=mask)

    # Store accumulated gradients (scalars)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
    tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
    tl.store(dB_ptr + row_block_id, dB_acc)
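

# Illustrative pure-PyTorch reference for the closed-form input gradient above
# (a sketch for verification only; `_poly_norm_grad_x_reference` is not part of
# the liger_kernel API). It mirrors the per-power terms of the kernel, with
# D_p = sqrt(mean((x^p)^2) + eps) = 1 / rstd_p and S_p = sum(dy * x^p) per row.
def _poly_norm_grad_x_reference(dy, x, w, eps=1e-6):
    d = x.shape[-1]
    dx = torch.zeros_like(x)
    for p, w_p in zip((3, 2, 1), (w[0], w[1], w[2])):
        x_p = x**p
        D = torch.sqrt(x_p.pow(2).mean(dim=-1, keepdim=True) + eps)
        S = (dy * x_p).sum(dim=-1, keepdim=True)
        dx += w_p * (p * x ** (p - 1) * dy / D - (p / d) * x ** (2 * p - 1) * S / D**3)
    # The matching weight/bias gradients are dW_p = S_p / D_p and dB = sum(dy).
    return dx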


def poly_norm_forward(X, W, B, eps=1e-6):
    """
    PolyNorm Forward Pass

    Args:
        X: input tensor of shape (*, H) where H is hidden dimension
        W: weight tensor of shape (3,) for [w0, w1, w2]
        B: bias scalar tensor
        eps: epsilon for numerical stability

    Returns:
        Y: output tensor of same shape as X
        X: reshaped input (for backward)
        RSTD: cached rstd values (for backward)
        BLOCK_SIZE: block size used
        num_warps: number of warps used
    """
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # RSTD caches the per-row rstd values (one per power p in {3, 2, 1}) for the backward pass
    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)

    # Check constraints
    assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
    assert B.numel() == 1, "Bias must be a scalar"

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        set_large_grf_mode(kernel_args)

    # Launch kernel
    _poly_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        B,
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
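

# Example usage of the host-side wrapper (illustrative sketch; the shapes,
# device and dtype below are assumptions):
#
#   x = torch.randn(2, 8, 256, device="cuda")
#   w = torch.randn(3, device="cuda")
#   b = torch.zeros((), device="cuda")
#   y, x_2d, rstd, block_size, num_warps = poly_norm_forward(x, w, b, eps=1e-6)
#   # y has shape (2, 8, 256); x_2d is the flattened (16, 256) view used by the
#   # kernel; rstd has shape (16, 3), one value per row for each power p.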


def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
    """
    PolyNorm Backward Pass

    Args:
        dY: gradient of output
        X: input tensor (already reshaped to 2D)
        W: weight tensor
        RSTD: cached rstd values from forward
        BLOCK_SIZE: block size from forward
        num_warps: number of warps from forward
        in_place: whether to modify dY in place to store dX (saves memory)

    Returns:
        dX: gradient w.r.t. input
        dW: gradient w.r.t. weight
        dB: gradient w.r.t. bias
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # Get the number of SMs (or equivalent compute units) for parallelization
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
    elif X.device.type == "npu":
        sm_count = get_npu_core_count()

    # Allocate or reuse gradients
    if in_place:
        dX = dY
    else:
        dX = torch.zeros_like(dY)

    _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
    _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    # XPU-specific optimization
    kernel_args = {}
    if X.device.type == "xpu":
        set_large_grf_mode(kernel_args)

    # Launch backward kernel
    _poly_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        W,
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        _dB,
        n_rows,
        n_cols,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        **kernel_args,
    )

    # Reduce gradients across SMs
    dX = dX.view(*shape)
    dW = _dW.sum(dim=0).to(W.dtype)
    dB = _dB.sum().to(W.dtype)

    return dX, dW, dB


class LigerPolyNormFunction(torch.autograd.Function):
    """
    PolyNorm Function with forward and backward pass

    PolyNorm formula:
        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
        where norm(u) = u / sqrt(mean(u²) + ε)

    Backward uses closed-form gradient:
        ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, B, eps=1e-6, in_place=True):
        """
        Args:
            X: input tensor of shape (B, T, H) or (BxT, H)
            W: weight tensor of shape (3,) for [w0, w1, w2]
            B: bias scalar
            eps: epsilon for numerical stability
            in_place: whether to modify grad_output in place during backward to store dX (saves memory)

        Returns:
            Y: output tensor of same shape as X
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.in_place = in_place
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """
        Args:
            grad_output: gradient of output

        Returns:
            dX, dW, dB: gradients w.r.t. X, W, B
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
        return dX, dW, dB, None, None
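

# Example end-to-end usage (illustrative sketch; the tensors and device below
# are assumptions, not part of the module):
#
#   x = torch.randn(4, 128, device="cuda", requires_grad=True)
#   w = torch.randn(3, device="cuda", requires_grad=True)
#   b = torch.zeros((), device="cuda", requires_grad=True)
#   y = LigerPolyNormFunction.apply(x, w, b, 1e-6, False)  # in_place=False keeps grad_output intact
#   y.sum().backward()
#   # x.grad, w.grad and b.grad are now populated and can be checked against the
#   # pure-PyTorch references sketched above.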