Unverified Commit 0af3fd7c authored by Tong WU's avatar Tong WU Committed by GitHub
Browse files

[BugFix] Refactor attention kernel to handle OOB positions by filling with...

[BugFix] Refactor attention kernel to handle OOB positions by filling with `-inf` instead of clearing accumulators. (#1222)

* Refactor attention kernel to handle OOB positions by filling with `-inf` instead of clearing accumulators.

* lint

* pre-commit

* Update imports in flash attention test file to use new backward and forward examples for better clarity and consistency.
parent eac96cd7
......@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
......
......@@ -59,7 +59,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
T.Cast(accum_dtype, -1e30))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
......
......@@ -54,7 +54,9 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
......
......@@ -96,7 +96,9 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......
......@@ -63,7 +63,9 @@ def flashattn(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......
......@@ -56,7 +56,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
......@@ -213,6 +215,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......
......@@ -52,7 +52,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
......@@ -206,6 +208,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
......@@ -340,7 +344,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--n_ctx', type=int, default=1048, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
args = parser.parse_args()
......
......@@ -53,7 +53,9 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
......@@ -193,6 +195,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# We don't need to handle OOB positions for non-causal cases,
# since OOB values won't affect other positions here.
T.wait_wgmma(0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
......
......@@ -55,7 +55,9 @@ def flashattn(batch,
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......@@ -226,7 +228,7 @@ if __name__ == "__main__":
parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--is_causal', action='store_true', help='causal', default=False)
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
......@@ -55,7 +55,9 @@ def flashattn(batch,
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......
......@@ -49,7 +49,10 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......
......@@ -49,7 +49,10 @@ def flashattn(batch,
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# We shall fill -inf for OOB positions
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
......
......@@ -2,7 +2,7 @@ import tilelang.testing
import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd
import example_mha_bwd_bshd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
......@@ -10,7 +10,7 @@ import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined
import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen
......@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():
@tilelang.testing.requires_cuda
def test_example_mha_bwd():
example_mha_bwd.main(
example_mha_bwd_bshd.main(
BATCH=1,
H=16,
N_CTX=512,
......@@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
@tilelang.testing.requires_cuda
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment