Unverified Commit 242cb457 authored by lijinpei, committed by GitHub
Browse files

[Example] Optimize online_softmax example (#934)



* [Example] Optimize online_softmax example

- Y should be output in float16.
- BN needs to be equal to N to be really online.
- On my H100 machine, this increases the speedup from 1.424x to 2.788x.

* enhance

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 5ccac4fa
...@@ -11,7 +11,7 @@ def softmax_kernel( ...@@ -11,7 +11,7 @@ def softmax_kernel(
N, N,
dtype: str = "float16", dtype: str = "float16",
) -> "Callable": ) -> "Callable":
BN = min(tl.next_power_of_2(N), 1024) BN = min(tl.next_power_of_2(N), 8192)
NN = tl.cdiv(N, BN) NN = tl.cdiv(N, BN)
accum_dtype = "float" accum_dtype = "float"
...@@ -21,7 +21,7 @@ def softmax_kernel( ...@@ -21,7 +21,7 @@ def softmax_kernel(
@T.prim_func @T.prim_func
def main( def main(
X: T.Tensor([M, N], dtype), X: T.Tensor([M, N], dtype),
Y: T.Tensor([M, N], accum_dtype), Y: T.Tensor([M, N], dtype),
): ):
with T.Kernel(M, threads=128) as (i_m): with T.Kernel(M, threads=128) as (i_m):
x = T.alloc_fragment([BN], dtype) x = T.alloc_fragment([BN], dtype)
...@@ -38,8 +38,7 @@ def softmax_kernel( ...@@ -38,8 +38,7 @@ def softmax_kernel(
T.reduce_max(x, max_x, dim=0, clear=True) T.reduce_max(x, max_x, dim=0, clear=True)
for j in T.Parallel(BN): for j in T.Parallel(BN):
exp_x[j] = T.if_then_else(j + i_n * BN < N, exp_x[j] = T.exp2(x[j] * scale - max_x[0] * scale)
T.exp2(x[j] * scale - max_x[0] * scale), 0)
T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True) T.reduce_sum(exp_x, sum_exp_x, dim=0, clear=True)
...@@ -49,9 +48,7 @@ def softmax_kernel( ...@@ -49,9 +48,7 @@ def softmax_kernel(
T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x) T.copy(X[i_m, i_n * BN:(i_n + 1) * BN], x)
for j in T.Parallel(BN): for j in T.Parallel(BN):
y[j] = T.exp2(x[j] * scale - lse[0])
if j + i_n * BN < N:
y[j] = T.exp2(x[j] * scale - lse[0])
T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN]) T.copy(y, Y[i_m, i_n * BN:(i_n + 1) * BN])
...@@ -63,7 +60,7 @@ N = 8192 ...@@ -63,7 +60,7 @@ N = 8192
kernel = softmax_kernel(M, N) kernel = softmax_kernel(M, N)
dtype = torch.float16 dtype = torch.float16
X = torch.randn(M, N, dtype=dtype, device="cuda") X = torch.randn(M, N, dtype=dtype, device="cuda")
Y = kernel(X).to(dtype) Y = kernel(X)
Y_ref = X.softmax(dim=1) Y_ref = X.softmax(dim=1)
torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2) torch.testing.assert_close(Y, Y_ref, rtol=1e-2, atol=1e-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment