[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls (#2392)
* Make BSHD default for Unfused DPA, DPA and MHA in TE JAX Signed-off-by:Kshitij Janardan Lakhani <klakhani@nvidia.com> * Remove explicit transpose_batch set for BSHD for DPA in JAX quickstart Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Add warnings in DPA and MHA to warn users of change defaults to BSHD instead of SBHD Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * Minimize the scope of when to trigger warnings for changed defaults for transpose_batch_sequence Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Kshitij Janardan Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment