Added dropout verification for flash attention forward (#593)
* saved dropout random numbers in gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp
* modified device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp
* added z tensor for storing the dropout random numbers (see the verification sketch after this list)
* added the z tensor in the example
* compiles now
* changed fp16 xdl to bf16
* fixed some bugs in example
* changed fwd file names
* fixed some bugs in fwd drop verify
* deleted device_grouped_multihead_attention_forward_xdl_cshuffle
* Fwd drop verify2 (#585)
* added group fwd mha dropout verify
* added dropout verify for grouped mha fp16 fwd
* added bf16 fwd attn dropout verify
* added dropout verify to batched mha fwd
* added batched flash attention fwd bf16 dropout verify
* changed some formatting
* added switch for lse storing in attn fwd (see the LSE sketch after this list)
* resolved conflicts in reference_dropout.hpp
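
Below is a minimal host-side sketch of how the saved z tensor can drive dropout verification, assuming the kernel records each element's random draw in z as a `uint16_t` and keeps an element when its draw falls at or under the keep threshold. The function name, draw type, and threshold convention are illustrative assumptions, not the kernel's actual interface.

```cpp
// Host-side reference dropout driven by the z tensor saved from the device
// kernel, so host and device drop exactly the same elements.
// Assumption: z holds per-element random draws in [0, 65535] and an element
// is kept when its draw is <= (1 - p) * 65535. Names are hypothetical.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> reference_dropout(const std::vector<float>& p_host,    // post-softmax attention values
                                     const std::vector<uint16_t>& z_saved, // draws captured by the kernel
                                     float dropout_prob)
{
    const auto threshold =
        static_cast<uint16_t>((1.0f - dropout_prob) * 65535.0f);
    const float rescale = 1.0f / (1.0f - dropout_prob); // inverted dropout

    std::vector<float> out(p_host.size());
    for(std::size_t i = 0; i < p_host.size(); ++i)
    {
        // Keep and rescale when the saved draw is under the threshold,
        // otherwise zero the element, mirroring the device-side decision.
        out[i] = (z_saved[i] <= threshold) ? p_host[i] * rescale : 0.0f;
    }
    return out;
}
```

With the same z on both sides, verification reduces to comparing the kernel's dropout output against this reference applied to the host-computed softmax, without reimplementing the device RNG on the host.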
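
And a minimal sketch of the LSE-storing switch, assuming the switch is a flag telling the forward softmax whether to write out the per-row log-sum-exp (`lse = m + log(sum(exp(s - m)))`) that the training backward pass consumes; the signature below is illustrative, not the kernel's interface.

```cpp
// Row-wise numerically stable softmax with an optional LSE store.
// The store_lse flag is the "switch": training enables it so the backward
// pass can reuse lse; inference skips the extra write. Assumes a non-empty row.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void softmax_row(const std::vector<float>& s, std::vector<float>& p,
                 float& lse_out, bool store_lse)
{
    float m = s[0];
    for(float v : s)
        m = std::max(m, v); // row max for numerical stability

    p.resize(s.size());
    float sum = 0.0f;
    for(std::size_t i = 0; i < s.size(); ++i)
    {
        p[i] = std::exp(s[i] - m);
        sum += p[i];
    }
    for(float& v : p)
        v /= sum;

    if(store_lse)
        lse_out = m + std::log(sum); // log-sum-exp of the raw row scores
}
```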
---------
Co-authored-by: ltqin <letao.qin@amd.com>