[JAX] Default to fused attention in JAX DPA (#2363)
* Default to fused attention in JAX DPA Signed-off-by:Kshitij Lakhani <klakhani@nvidia.com> * Consolidate documentation for DPA in JAX Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> * Correctly update the documentation for defaults in JAX DPA Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Signed-off-by:
Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Co-authored-by:
greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment