online_softmax.py

import torch
import tilelang as tl
import tilelang.language as T
from tilelang.profiler import do_bench
from typing import Callable


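# Row-wise softmax in TileLang: pass 1 accumulates the row's log-sum-exp tile by tile
# (online), pass 2 normalizes. out_idx=[1] tells tl.jit that parameter 1 (Y) is an
# output tensor the compiled kernel allocates and returns.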
@tl.jit(out_idx=[1])
def softmax_kernel(
    M,
    N,
    dtype: str = "float16",
) -> "Callable":
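    # Cover each row with NN tiles of BN columns (next power of two of N, capped at 8192).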
    BN = min(tl.next_power_of_2(N), 8192)
    NN = tl.cdiv(N, BN)

    accum_dtype = "float"

    scale = 1.44269504  # log2(e): e**t == 2**(t * scale), so exp/log can be done with exp2/log2

    @T.prim_func
    def main(
            X: T.Tensor([M, N], dtype),
            Y: T.Tensor([M, N], dtype),
    ):
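        # One thread block per row (block index i_m); 128 threads cooperate on each tile.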
        with T.Kernel(M, threads=128) as (i_m):
            x = T.alloc_fragment([BN], dtype)
            y = T.alloc_fragment([BN], dtype)
            lse = T.alloc_fragment([1], accum_dtype)
            max_x = T.alloc_fragment([1], dtype)
            exp_x = T.alloc_fragment([BN], accum_dtype)
            sum_exp_x = T.alloc_fragment([1], accum_dtype)
            T.fill(lse, -T.infinity(accum_dtype))

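            # Pass 1: stream the row tile by tile, maintaining a running base-2 log-sum-exp in lse.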
            for i_n in T.Pipelined(0, NN):
                T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)

                T.reduce_max(x, max_x, dim=0, clear=True)

                for j in T.Parallel(BN):
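                    # exp(x_j - max_x) evaluated as exp2((x_j - max_x) * log2(e))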
                    exp_x[j] = T.exp2(x[j] * scale - max_x[0] * scale)

                T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True)

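                # Online update: 2**lse_new = 2**lse_old + sum_j e**x_j over this tile.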
                lse[0] = max_x[0] * scale + T.log2(T.exp2(lse[0] - max_x[0] * scale) + sum_exp_x[0])

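            # Pass 2: re-read each tile and normalize it with the full-row log-sum-exp.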
            for i_n in T.Pipelined(0, NN):
                T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)

                for j in T.Parallel(BN):
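                    # softmax(x_j) = e**x_j / sum_k e**x_k = 2**(x_j * scale - lse)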
                    y[j] = T.exp2(x[j] * scale - lse[0])

                T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN])

    return main


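# Host-side check and benchmark on an 8192 x 8192 fp16 input.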
M = 8192
N = 8192
kernel = softmax_kernel(M, N)
dtype = torch.float16
X = torch.randn(M, N, dtype=dtype, device="cuda")
Y = kernel(X)
Y_ref = X.softmax(dim=1)

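# Verify against PyTorch's reference softmax.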
torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2)

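# Benchmark both implementations.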
t1 = do_bench(lambda: X.softmax(dim=1), warmup=25, rep=100)
t2 = do_bench(lambda: kernel(X), warmup=25, rep=100)
print(f"torch latency: {t1:.3f} ms")
print(f"TileLang latency: {t2:.3f} ms")
print(f"Speedup: {t1/t2:.3f}x")