Unverified Commit 3ab93cd7 authored by Tong WU's avatar Tong WU Committed by GitHub
Browse files

[Enhancement] Keep max score attention across blocks in FlashAttention for...


[Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stablity (#1269)

* Implement max score retention across blocks in FlashAttention for improved stability

* fix manual pipeline parameters

* Update examples/flash_attention/example_gqa_fwd_varlen.py
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix typo

* more

* fix a previous typo

---------
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent b3d6f03c
......@@ -95,6 +95,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
......
......@@ -178,6 +178,8 @@ def fast_flashattn(
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
if m_prev[i] == -T.infinity(accum_dtype):
......
......@@ -171,6 +171,8 @@ def fast_flashattn(
T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])
for i in T.Parallel(block_M):
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
......
......@@ -99,6 +99,8 @@ def flashattn_fwd(
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
......
......@@ -105,6 +105,8 @@ def flashattn(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
......@@ -181,7 +183,7 @@ def flashattn(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
......
......@@ -96,6 +96,8 @@ def flashattn_fwd(
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
......
......@@ -95,6 +95,8 @@ def flashattn(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
......
......@@ -98,6 +98,8 @@ def flashattn(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
......@@ -174,7 +176,7 @@ def flashattn(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
......
......@@ -105,8 +105,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -95,8 +95,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -92,6 +92,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -91,6 +91,8 @@ def flashmla_decode(batch,
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......@@ -157,6 +159,8 @@ def flashmla_decode(batch,
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -74,6 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......@@ -148,6 +150,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -93,6 +93,8 @@ def mla_decode_tilelang(batch,
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......@@ -176,6 +178,8 @@ def mla_decode_tilelang(batch,
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -98,6 +98,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
......
......@@ -104,7 +104,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
T.reduce_max(acc_s, out=m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
......@@ -137,6 +139,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
......@@ -324,6 +328,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
......@@ -356,6 +362,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
......
......@@ -74,6 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
......
......@@ -147,6 +147,8 @@ def sparse_mla_fwd(
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
......
......@@ -164,6 +164,8 @@ def sparse_mla_fwd(
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
......@@ -198,6 +200,8 @@ def sparse_mla_fwd(
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
......
......@@ -77,6 +77,8 @@ def flash_attention(
# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment