Add support for SWA (left, right) with FusedAttention (#2477)
* SWA (left, right) with FusedAttention changes cherry-picked from https://github.com/NVIDIA/TransformerEngine/pull/1369 Signed-off-by:Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test_kv_cache failures Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * remove unnecessary comments Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * fix some more filter issues, address feedback Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * make conditions more accurate Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * add cp tests to test swa (left, right) Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove dead code and make conditions better Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feedback form Charlene Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * small er Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * plumb `bottom_right_diagonal` through jax Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * plumb `bottom_right_diagonal` through jax Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add missing fields Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * use proper mask type in CP Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by:
Sudhakar Singh <sudhakars@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment