[JAX] Refactor fused attention (#711)
* Remove unused headers
* Unify the fused attn workspace size cpp code
* Reduce the skipped cases
* Rename self/cross attention to qkvpacked/kvpacked
* Update attention mask docs
* Refine the attn mask implementations

Signed-off-by: Reese Wang <rewang@nvidia.com>
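The qkvpacked/kvpacked rename reflects the tensor layouts the fused kernels consume: self-attention packs Q, K, and V into one tensor, while cross-attention keeps Q separate and packs only K and V. A minimal NumPy sketch of the two layouts, assuming a typical packing axis of size 3 (or 2) between the sequence and head axes; the exact shapes and axis order are illustrative, not taken from the PR:

```python
import numpy as np

batch, seqlen, heads, head_dim = 2, 16, 4, 8

# qkvpacked (self-attention): Q, K, V share one tensor; packing axis has size 3.
qkv = np.zeros((batch, seqlen, 3, heads, head_dim), dtype=np.float16)
q, k, v = np.split(qkv, 3, axis=2)  # each: (batch, seqlen, 1, heads, head_dim)

# kvpacked (cross-attention): Q is separate; K and V share a tensor of size 2
# on the packing axis.
q_sep = np.zeros((batch, seqlen, heads, head_dim), dtype=np.float16)
kv = np.zeros((batch, seqlen, 2, heads, head_dim), dtype=np.float16)
k2, v2 = np.split(kv, 2, axis=2)

print(qkv.shape, kv.shape)
```

Naming the variants by layout (qkvpacked/kvpacked) rather than by use case (self/cross) describes what the kernel actually requires, since either attention pattern can in principle be fed either packing.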