[JAX] Debug distributed attention tests (#1038)
* Remove extra args to fused attention func Signed-off-by:Tim Moon <tmoon@nvidia.com> * Add missing arg to fused attention func Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
Showing
Please register or sign in to comment