[JAX] Refactor and trim TE JAX Attn testing (#2542)
* Pick a leaner set of combinations for TE JAX CP attn tests such that only those cp,dp,tp combinations are picked where cp*dp*tp is equal to num gpus Signed-off-by:Kshitij Lakhani <klakhani@nvidia.com> * Consolidate the test cases run for different B,S,H,D and QKV layout Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code and comments clean up Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make FP16 + GQA test cross attn instead of self attn to generalize the test Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment