@@ -229,12 +314,27 @@ def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, hea
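    # Number of key/query blocks along seqlen_k and seqlen_q (ceiling division),
    # and the number of 16-wide MMA tiles per block along the n dimension.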
    nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
    nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
    mmas_n = (blocksize_n + 16 - 1) // 16
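    # Split S into (nblocks_m, nblocks_n) tiles and flatten each (blocksize_m, blocksize_n) tile.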
    S_flat = rearrange(S, 'b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)',
                       blocksize_m=blocksize_m, blocksize_n=blocksize_n)
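    # Undo the MMA fragment layout within each tile, restoring row-major (seqlen_q, seqlen_k) order.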
    S_converted = rearrange(S_flat, 'b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)',