Added dropout verification for flash attention forward (#593)
* saved dropout random number in gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
* modified device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
* added z tensor for dropout storing
* added z in example
* can compile now
* change fp16 xdl to bf16
* fixed some bugs in example
* changed fwd file names
* fixed some bugs in fwd drop verify
* Delete device_grouped_multihead_attention_forward_xdl_cshuffle
* Fwd drop verify2 (#585)
* fixed some bugs in fwd drop verify
* Delete device_grouped_multihead_attention_forward_xdl_cshuffle
* added group fwd mha dropout verify
* added dropout verify for grouped mha fp16 fwd
* added bf16 fwd attn dropout verify
* added dropout verify to batched mha fwd
* added batched fla fwd bf16 dropout verify
* changed some format
* added switch for lse storing in attn fwd
* added switch for lse storing in attn fwd
* resolved conflicts in refere...
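The commits above store the dropout random decisions in a z tensor so that a host reference can reproduce the kernel's dropped-and-scaled output exactly. A minimal sketch of that verification idea is below; the function names, shapes, and RNG are hypothetical illustrations, not CK's actual API:

```python
import random

def dropout_forward(x, p, seed):
    """Apply inverted dropout, returning both the output and the saved
    keep/drop mask (the role played by the 'z' tensor in the commits)."""
    rng = random.Random(seed)
    z = [1 if rng.random() >= p else 0 for _ in x]  # saved random decisions
    scale = 1.0 / (1.0 - p)  # inverted-dropout rescaling
    out = [xi * zi * scale for xi, zi in zip(x, z)]
    return out, z

def verify_dropout(x, out, z, p):
    """Host-side check: recompute the output from x and the stored mask z
    and compare element-wise against what the 'device' produced."""
    scale = 1.0 / (1.0 - p)
    ref = [xi * zi * scale for xi, zi in zip(x, z)]
    return all(abs(a - b) < 1e-6 for a, b in zip(out, ref))

x = [0.5, -1.0, 2.0, 3.5]
out, z = dropout_forward(x, p=0.2, seed=42)
assert verify_dropout(x, out, z, p=0.2)
```

Because the mask is materialized rather than regenerated, the check does not depend on reproducing the device RNG on the host, which is the point of writing z out alongside the attention output.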