[JAX] Add THD + SWA unit tests (#1390)
* Fix SWA mask for THD and forcing seqlen_kv >= seqlen_q for SWA Signed-off-by:Reese Wang <rewang@nvidia.com> * Generalize sliding window mask Signed-off-by:
Reese Wang <rewang@nvidia.com> * Fix pylint Signed-off-by:
Reese Wang <rewang@nvidia.com> --------- Signed-off-by:
Reese Wang <rewang@nvidia.com>
Showing
Please register or sign in to comment