"library/vscode:/vscode.git/clone" did not exist on "712e464c4e437a5aaa2fe47bb8161b8f1946e501"
Unverified commit 242cb457, authored by lijinpei, committed by GitHub
Browse files

[Example] Optimize online_softmax example (#934)



* [Example] Optimize online_softmax example

- Y should be output in float16.
- BN needs to be equal to N to be really online.
- On my H100 machine, this increases the speedup from 1.424x to 2.788x.

* enhance

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 5ccac4fa
......@@ -11,7 +11,7 @@ def softmax_kernel(
N,
dtype: str = "float16",
) -> "Callable":
BN = min(tl.next_power_of_2(N), 1024)
BN = min(tl.next_power_of_2(N), 8192)
NN = tl.cdiv(N, BN)
accum_dtype = "float"
......@@ -21,7 +21,7 @@ def softmax_kernel(
@T.prim_func
def main(
X: T.Tensor([M, N], dtype),
Y: T.Tensor([M, N], accum_dtype),
Y: T.Tensor([M, N], dtype),
):
with T.Kernel(M, threads=128) as (i_m):
x = T.alloc_fragment([BN], dtype)
......@@ -38,8 +38,7 @@ def softmax_kernel(
T.reduce_max(x, max_x, dim=0, clear=True)
for j in T.Parallel(BN):
exp_x[j] = T.if_then_else(j + i_n * BN < N,
T.exp2(x[j] * scale - max_x[0] * scale), 0)
exp_x[j] = T.exp2(x[j] * scale - max_x[0] * scale)
T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True)
......@@ -49,9 +48,7 @@ def softmax_kernel(
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
for j in T.Parallel(BN):
if j + i_n * BN < N:
y[j] = T.exp2(x[j] * scale - lse[0])
y[j] = T.exp2(x[j] * scale - lse[0])
T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN])
......@@ -63,7 +60,7 @@ N = 8192
kernel = softmax_kernel(M, N)
dtype = torch.float16
X = torch.randn(M, N, dtype=dtype, device="cuda")
Y = kernel(X).to(dtype)
Y = kernel(X)
Y_ref = X.softmax(dim=1)
torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment