[JAX] Add BRCM support for THD (#2242)
* Add BRCM support when creating a test mask for fused attn Signed-off-by:Kshitij Lakhani <klakhani@nvidia.com> * Add support for BRCM to correctly generate the mask needed for calculating the seqlens and offsets for THD Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Skip drop=0 and no_bias case for BRCM as cuDNN does not suport this Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Skip BRCM test cases where max_seqlen_q > max_seqlen_kv Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Refactor the segment id run length code for BRCM seqoffset and seqlens calculations Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix the drop inequality skip condition in fused attn Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * nit: Adjust the BRCM id name in the test to make it consistent Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix the brcm mask condition. Fix the condition for cross atnn type pattern to only apply for brcm Change the num segments per sequence to 3 instead of 2 Reduce one test pattern data size and make it such that it triggers brcm Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix lint errors Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Fix incorrectly changed dtype to numpy bool_ rather than native python bool Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Restore the numsegments to earlier value Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> * Add example for THD BRCM Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> --------- Signed-off-by:
Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment