[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)
* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by:Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Assert in fused attn bwd pass for sm100 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Add check for sm100 Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Represent attn bias using enum instead of string Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment