[JAX] Fix correctness of JAX fused attention with CP and improve numerics...
[JAX] Fix correctness of JAX fused attention with CP and improve numerics check in unit tests (#1282)
Fix correctness of JAX fused attention with CP.
Signed-off-by:
Michael Goldfarb <mgoldfarb@nvidia.com>
Showing
Please register or sign in to comment