Commit 0af3fd7c (unverified), authored by Tong WU, committed by GitHub
Browse files

[BugFix] Refactor attention kernel to handle OOB positions by filling with...

[BugFix] Refactor attention kernel to handle OOB positions by filling with `-inf` instead of clearing accumulators. (#1222)

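* Refactor the attention kernels to handle out-of-bounds (OOB) key positions by pre-filling the score accumulator with `-inf` at those positions (and `0` elsewhere) instead of clearing the whole accumulator to zero.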

* lint

* pre-commit

* Update imports in flash attention test file to use new backward and forward examples for better clarity and consistency.
parent eac96cd7
...@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
......
...@@ -59,7 +59,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -59,7 +59,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
T.Cast(accum_dtype, -1e30)) T.Cast(accum_dtype, -1e30))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
......
...@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc ...@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
......
...@@ -96,7 +96,9 @@ def flashattn(batch, ...@@ -96,7 +96,9 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
......
...@@ -63,7 +63,9 @@ def flashattn( ...@@ -63,7 +63,9 @@ def flashattn(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
......
...@@ -56,7 +56,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -56,7 +56,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
...@@ -213,6 +215,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -213,6 +215,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0) 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do) T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
...@@ -52,7 +52,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -52,7 +52,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
...@@ -206,6 +208,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -206,6 +208,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0) 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT) T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -340,7 +344,7 @@ if __name__ == "__main__": ...@@ -340,7 +344,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size') parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads') parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') parser.add_argument('--n_ctx', type=int, default=1048, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension') parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag') parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -53,7 +53,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -53,7 +53,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
...@@ -193,6 +195,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): ...@@ -193,6 +195,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0) 0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.wait_wgmma(0) T.wait_wgmma(0)
T.copy(qkT, qkT_cast) T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
......
...@@ -55,7 +55,9 @@ def flashattn(batch, ...@@ -55,7 +55,9 @@ def flashattn(batch,
k_idx = k * block_N + j k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
...@@ -226,7 +228,7 @@ if __name__ == "__main__": ...@@ -226,7 +228,7 @@ if __name__ == "__main__":
parser.add_argument('--seq_q', type=int, default=256, help='query sequence length') parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length') parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim') parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal') parser.add_argument('--is_causal', action='store_true', help='causal', default=False)
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
...@@ -55,7 +55,9 @@ def flashattn(batch, ...@@ -55,7 +55,9 @@ def flashattn(batch,
k_idx = k * block_N + j k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
......
...@@ -49,7 +49,10 @@ def flashattn(batch, ...@@ -49,7 +49,10 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
......
...@@ -49,7 +49,10 @@ def flashattn(batch, ...@@ -49,7 +49,10 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype)) -T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) # We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
......
...@@ -2,7 +2,7 @@ import tilelang.testing ...@@ -2,7 +2,7 @@ import tilelang.testing
import example_gqa_bwd import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd import example_mha_bwd_bshd
import example_mha_bwd_bhsd import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd import example_gqa_fwd_bshd
...@@ -10,7 +10,7 @@ import example_mha_fwd_bshd ...@@ -10,7 +10,7 @@ import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen import example_gqa_bwd_tma_reduce_varlen
...@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined(): ...@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_example_mha_bwd(): def test_example_mha_bwd():
example_mha_bwd.main( example_mha_bwd_bshd.main(
BATCH=1, BATCH=1,
H=16, H=16,
N_CTX=512, N_CTX=512,
...@@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd(): ...@@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0) @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined(): def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment