Unverified commit a2a27814 authored by Varuna Jayasiri, committed by GitHub

[EXAMPLE] In the flash attention example, keep the max of all blocks seen in scores_max for numerical stability (#1148)

* Keep the max of all blocks seen in scores_max for stability

* ruff formatting
parent 041d4a06
@@ -86,6 +86,10 @@ def flashattn(batch,
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                # To do causal softmax, we need to set scores_max to 0 if it is -inf.
                # This process is called Check_inf in the FlashAttention3 code, and it only needs
                # to be done in the first ceil_div(kBlockM, kBlockN) steps.
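The change above makes the per-row maximum monotonically non-decreasing across key/value blocks: after reducing over the current block, each row's max is combined with the max of all previously seen blocks. This guarantees `exp(score - scores_max)` never exceeds 1, so the accumulated softmax cannot overflow. A minimal NumPy sketch of the same running-max bookkeeping (function and variable names are hypothetical and only mirror the kernel's `scores_max` / `scores_max_prev`; this is not the TileLang example itself):

```python
import numpy as np


def running_scores_max(blocks):
    """Track the row-wise running max over a sequence of score blocks.

    Mirrors the patched kernel loop: each step copies the old running max,
    reduces the new block along its columns, and keeps the elementwise
    maximum of the two, so the running max never decreases.
    """
    # Initialise to -inf, as T.fill(scores_max, -T.infinity(...)) does.
    scores_max = np.full(blocks[0].shape[0], -np.inf)
    for acc_s in blocks:
        scores_max_prev = scores_max.copy()      # T.copy(scores_max, scores_max_prev)
        block_max = acc_s.max(axis=1)            # T.reduce_max(acc_s, ..., dim=1)
        # The fix: fold in the previous running max instead of using only
        # the current block's max (the T.Parallel loop in the diff).
        scores_max = np.maximum(block_max, scores_max_prev)
    return scores_max
```

Without the `np.maximum` step, a later block with smaller scores would shrink the max, and the rescaling factor `exp(scores_max_prev - scores_max)` applied to the accumulator could exceed 1, reintroducing the overflow risk the online softmax is meant to avoid.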