Unverified Commit fd0cd12e authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add CP + THD + AG + Striped>1 + SWA support (#2379)



* Add generic stripe_height support for load balancing
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix imports in test for deprecated jax.experimental.pjit
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add test case for stripe_height greater than 1. Add stripe_height arg to reordering methods
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add Striped 1 and 4 test cases. Refactor the Load Balancing test case. Fix the incorrect shape in striping inverser reordering
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Modify test code for CP + AG + THD + stripe height greater than 1
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add stripe_height arg to fused attn and fused attn fwd API. Add appropriate mask checks for AG+THD+CP and pick BRCM to be executed per rank. Add Fused Attn Primitive for CP + THD +AG + Striping. Add a method to reorder and all gather segment ids and offsets for kv
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway testing commit
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add comments in primitive registration process
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway test commit
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Undoing incorrect rebase/merge leftovers
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway test commits
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add support for calculating q and kv seqlens and offsets per rank for CP+THD+AG+SW+Striped>1 primitive
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Augment jax primitive register code comments
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Fix the array sizes and padding values returned for seqlens and offsets to fit what the fused attn primitive non cp computation
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add support in new primitive for softmax_offset related changes. Put in missing primitive registering line in again. Increase the seqoffsets arrays lengths by 1
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Add new set of helper functions for seqlens and seqoffsets fo AG+THD+CP+Stripe>1 which accounts for batching and seq offsets size b+1
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add backward primitive for CP+THD+AG+Striped>1
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Modify tests for backward primitive for CP+THD+AG+Striped>1
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Move stripe_height along with other static args in fused_attn_bwd rule. Fix typo in CP+AG+TH+Striped>1 primitive
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Code clean up: remove older version for calculating seqlens and offsets for CP+AG+THD+striped>1 primitive
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add test for CP+THD+AG+Striped>1
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix missing var
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add SWA tests for AG+Striped>1+CP+THD+SWA
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Restoring test code
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove assert preventing SWA code path in CP+AG+Striped primitive
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Parametrize num_segments_per_seq in tests
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Clean up test code
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Clean up test code in TE common
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Clean up debug statements
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Rename stripe_height to stripe_size
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Code clean up and add additional comments
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

nit: Apply suggestions from code review
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarKshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

Fix type on fused attn tests
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Fix seqoffsets length to be passed onto FusedAttn primitive as it is b and not b+1 needed by cuDNN
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove commented code
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarKshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

Fix linting issues
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Fix incorrect greptile change
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip THD test cases for CP + AG + Dual chunk. Skip BSHD cases for CP + AG + Striped>1. Correct the layout and shapr parameters passed to the tests
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Pass stripe_size explicitly for ring attn tests for THD cases
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove TODO
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Explicitly fail if THD + AG is being used with a non padding causal mask
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nit: Correct the ID for the test dist fused attn tests to account for cp*2 which is done under the hood
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Set num_segments_per_seq defaults to None instead of 0
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Augment comments. Add ValueError for stripe_size=0
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Test only 1 num_segments_per_seq combination for CP+AG+THD+Striped>1+SWA instead of 2. Modify the num segments and window size to easily to debug values
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Default stripe_size to None instead of 0. Modify stripe_size check for <=0 instead of ==0
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove incorrectly added file
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Explicitly pass zero sized arrays for seg ids and pos in the CP + AG + Striped primitive rather than using the seqlens or the offsets as placeholders
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Fix linting errors
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add a deep dive doc for CP+THD+AG+Stripe>1+SWA regarding design considerations and decisions
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Put docs and pngs into it's separate dir
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Replace png screenshots with markdown coe blocks for the attention patterns. Remove unecessary pngs
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Add doc file to index.rst. Fix grammatical errors
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarKshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: default avatarKshitij  Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f0572aa5
{
"cells": [
{
"cell_type": "markdown",
"id": "14efeb1e",
"metadata": {},
"source": [
"## Deep Dive into CP + THD + AG + Striped>1 + SWA support for Transformer Engine JAX\n",
"This feature was merged as part of [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/) and was made available in Transformer Engine v2.11. This document addresses 3 fundamental questions about the design considerations and the implementation logic for this feature."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16f738c7",
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "2f31119f",
"metadata": {},
"source": [
"### Question 1: Why choose Striped>1 ?\n",
"\n",
"Prior to the addition of this feature, Transformer Engine JAX attention already supported load balancing via a striping pattern, i.e., `stripe_size=1` for `CP + THD + P2P(Ring) + Striped + SWA`. However, this reordering technique does not lend itself well to an all-gathered (post-AG) pattern. The following example illustrates this distinction. For this example, `cp_size=4`, `num_segments=4`, `window_size=(8,0)`, and the pattern is for a single rank after striped reordering has been performed: \n",
"\n",
"#### I. Striped (`stripe_size=1`)\n",
"- Such a staggered pattern is not supported by cuDNN\n",
"- One possible way to express this with cuDNN support is by treating each `q` token as a segment, thereby producing 16 segments with varying `kv` token counts. However, this is very inefficient and does not scale well as max_seqlens increases\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - 1 1 1 1 1 1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 4 4 4 4 4 4 4 - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 1: Post load balancing using stripe_size=1 and post AG attention pattern for a single cp rank </figcaption>\n",
"</figure>\n",
"\n",
"\n",
"#### II. Striped > 1 (`stripe_size > 1`)\n",
"- This pattern is supported by cuDNN, with a suggested `stripe_size=128`\n",
"- The mask type supported by `CP + THD + AG + Striped>1 + SWA` is `PADDING_CAUSAL_MASK`; however, to express the pattern below, each rank executes THD + SWA using `PADDING_BOTTOM_RIGHT_CAUSAL_MASK`\n",
"- `max_num_segments_for_rank` needs to be estimated. The estimation formula used is: `max_seqlens // (stripe_size * cp_size) + max_num_segments`\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 4 4 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 2: Post load balancing using stripe_size=4 and post AG attention pattern for a single cp rank </figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "6eddfa7a",
"metadata": {},
"source": [
"### Question 2: Why is there a need for separate helper functions for calculating seqlens and offsets ?\n",
"\n",
"The seqlens and offsets are calculated by the fused attn JAX primitives (both, CP and non-CP) so that they can be passed down to `fused_attn_arbitrary_seqlen_fwd_impl()` / `fused_attn_arbitrary_seqlen_bwd_impl()`, where it is translated before passing down to the cuDNN FE layer. The current (Transformer Engine v2.10) calculation of seqlens and offsets entails the CP primitive passing the sharded segment_ids, segment_pos, seq_lens, seq_offsets stuffed in a SequenceDescriptor object (a convenience class provided for packing these 4 tensors) to the `FusedAttnPrimitive`, which in turn calls `get_seqlens_and_offsets()` on the SequenceDescriptor object. \n",
"\n",
"If `get_seqlens_and_offsets()` receives a SequenceDescriptor object with seq_lens and seq_offsets populated and, segment_ids, segment_pos with size=0, it returns the seq_lens and seq_ofsets as it is (for e.g. `CP + BSHD + AG`). However, if `get_seqlens_and_offsets()` receives a SequenceDescriptor object with segment_ids and segment_pos populated and, seq_lens, seq_offsets with size=0, it first constructs a mask using the segment_ids and segment_pos and then extracts the seq_lens and seq_offsets from it and then returns it (for e.g. `CP + THD + P2P`).\n",
"\n",
"The problem with the current approach of calculating a mask followed by extracting the seq_lens and seq_offsets is that it is unable to express the patterns seen in `CP + THD + AG`. Below is one such example: \n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 3: Example 1 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) . </figcaption>\n",
"</figure>\n",
"\n",
"Here, ideally, the two sections of the segment 3 should be split into two different segments (segment 3_1 formed using rows 9-12 and segment 3_2 formed using rows 13-16) as cuDNN does not support segment 3's entire staggered shape (as discussed earlier) , however, the mask route is unable to make this distinction, and it ends up treating it as one large segment thereby performing unnecessary computations of the padded regions in segment 3(rows 9-12 )\n",
"\n",
"In the below example, the mask route takes the `kv_seqlens` for segment 1 to be 6 and masks it using Bottom Right Causal Mask rather than taking `kv_seqlens` of 4 and masks it using Bottom Right Causal Mask, resulting in incorrect results\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 4: Example 2 for problem using mask path in get_seqlens_and_offsets() for attention pattern (post striping and AG) </figcaption>\n",
"</figure>\n",
"\n",
"The second case can be resolved in the mask path, but that would require adding CP specific details to the non-CP FusedAttn primitive which would contaminate it. Besides, resolving the first case would be even trickier with this approach. Due to it being incompatible with the design of FusedAttn primitive and inadequate to express the pattern needed for `CP + THD + AG` fully, separate helper functions were created which calculate the seqlens and seqoffsets, without creating a mask, hence also being O(N) space."
]
},
{
"cell_type": "markdown",
"id": "3cc4a12c",
"metadata": {},
"source": [
"### Question 3: What is the implementation logic for the separate helper functions ?\n",
"\n",
"This section discusses the implementation logic for two of these four helper functions which serve as a reference, as the other two are using similar principles. Consider the test example in the code block, for which, `cp_size=4`, `stripe_size=4`, `max_seqlens=64`, `num_segments=2` and no SWA for simplicity. seg_1 has 8 valid tokens + 13 padded tokens and seg_2 has 31 valid tokens + 1 padded token. The 0 is used to explicitly show the padded region of seg_1 which is reordered, but for computation purposes it is equivalent to any of the `-` marked elements.\n",
"\n",
"```\n",
"segment_ids_q_0_reordered = segment_ids_kv_0_reordered = jnp.array([[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]])\n",
"\n",
"segment_pos_q_0_reordered = segment_pos_kv_0_reordered = jnp.array([[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]])\n",
"\n",
"segment_ids_kv_0_seed12_ag_inv_reordered = jnp.array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"\n",
"segment_pos_kv_0_seed12_ag_inv_reordered= jnp.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"```\n",
"\n",
"```\n",
"1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"1 1 1 1 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - - -\n",
"- - - - - - - - - - - - - - - - - - - - - 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 - - - - - - - - - - - -\n",
"```\n",
"<figure align=\"left\">\n",
"<figcaption> Figure 5: An example of post striped reordering and AG attention pattern on a single rank.</figcaption>\n",
"</figure>\n",
"\n",
"#### I. Implementation logic for q_seqlens_for_striped_for_rank()\n",
"**What is the objective/logic ?**\n",
"- Create a new set of segment ids for this rank such that:\n",
" - It gets rid of padding information as it does not contribute to the seqlens calculation\n",
" - It has the ability to identify ”new segments” being created from the same original segment\n",
"- Use this new set of segment ids to calculate the seqlens\n",
"\n",
"**Example walkthrough**\n",
"1. Calculate the non-zero indices (where seg ids !=0)\n",
"2. Calculate the valid seg ids and valid seg pos (i.e. index into seg ids and seg pos using the non-zero indices)\n",
" - `valid_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0]]`\n",
" - `valid_segment_pos=[[0, 1, 2, 3, 11, 12, 13, 14, 27, 28, 29, 30, 0, 0, 0, 0]]`\n",
" - Ignore the 0s at the end of the two arrays as they are just for padding to a static length\n",
"3. Find locations where a q segment change/break happens. A segment change happens when: \n",
" - there is a change in valid_segment_ids OR \n",
" - `valid_segment_pos[i+1] != valid_segment_pos[i]`\n",
" - `segment_changes=[[True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]`\n",
"4. Perform a cumulative sum on the segment changes: \n",
" - `new_segment_ids=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 5, 6, 7]]`\n",
"5. Filter out the valid indices only and pad at the end with 0s upto static length (these are our “new” segment indices without padding)\n",
" - `new_segment_ids_filtered=[[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]`\n",
" - Notice here that the large chunk of 8 q token rows (rows 9-16 in Fig 5) gets broken down into 2 \"new\" segments of 4 q token rows each,\n",
" which is a pattern that cuDNN supports and it ensures that wasted computation for padded regions of rows 9-12 is not performed, which was the\n",
" case in Fig 3\n",
"6. Perform a bin count and pad with -1s upto `max_num_segments_per_seq_for_rank`\n",
" - `seqlens_with_neg1_padding[[ 4, 4, 4, -1, -1, -1, -1]]`\n",
"\n",
"\n",
"#### II. Implementation logic for kv_seqoffsets_for_striped_for_rank()\n",
"**What is the objective/logic ?**\n",
"- Get the original segment ids for those locations where segment changes happen (arr1)\n",
" - Each segment has a known kv offset, hence if we know which original segment id a \"new\" segment is associated with we can find it's kv offset\n",
" - So, for e.g., in Fig 5, all valid tokens of seg_3 have the same kv offset, so even if this gets split into a 2 \"new\" segments, we can procure the offset for both using a mapping of original seg-ids to kv offset \n",
"- Get the segment ids for those locations where segment changes happen in the AG tensor (arr2)\n",
" - This is used to create a kind of mapping between original seg-ids to kv offset\n",
"- Pick values from arr2 mapping for the \"new\" segment ids collected in arr1\n",
"\n",
"**Example walkthrough**\n",
"1. Find locations where a kv segment pos change/break happens and mask out zero seg ids. A segment change happens when: \n",
" - `kv_segment_pos[i+1] != kv_segment_pos[i]`\n",
" - `segment_changes_masked=[[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]`\n",
"2. Get the indices where the segment changes happen and the segment ids associated with them:\n",
" - `segment_changes_indices=[[0, 8, 12, -1, -1, -1, -1, -1, -1]]`\n",
" - `[[1, 2, 2, -1, -1, -1, -1, -1, -1]]`\n",
"3. Find the segment pos changes/break for the AG seg pos and mask out zero seg ids\n",
" - `segment_changes_masked_ag=[[True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]]`\n",
"4. Get indices where the segment changes happen for the AG seg pos (this works as a mapping between segment ids and kv offsets)\n",
" - `segment_changes_ag_indices=[[0, 21, -1, -1, -1, -1, -1, -1, -1]]`\n",
"5. Get the seq offsets by indexing into segment_changes_ag_indices using segment_changes_indices :\n",
" - `kv_seq_offsets[[0, 21, 21, -1, -1, -1, -1, -1, -1]]`\n",
"\n",
"The implementation details for `q_seqoffsets_for_striped_for_rank()` and `kv_seqlens_for_striped_for_rank()` can be found in [PR 2379](https://github.com/NVIDIA/TransformerEngine/pull/2379/)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -56,3 +56,4 @@ Transformer Engine documentation
api/c/index
debug
examples/attention/attention.ipynb
examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
......@@ -327,9 +327,9 @@ DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
]
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCPx2-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCPx2-16-64"),
]
......@@ -351,12 +351,14 @@ class TestDistributedContextParallelSelfAttn:
use_shardy,
use_scan_ring=False,
window_size=None,
stripe_size=None,
num_segments_per_seq=None,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
if not load_balanced and (
cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
......@@ -382,7 +384,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
runner = FusedAttnRunner(
batch,
seqlen,
......@@ -401,6 +402,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape,
window_size,
SeqDescFormat.SegmentIDs,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -453,7 +456,7 @@ class TestDistributedContextParallelSelfAttn:
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
......@@ -470,6 +473,8 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
......@@ -486,6 +491,72 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=True,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED")],
)
@pytest.mark.parametrize(
"stripe_size",
[pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
)
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((5, 0), id="window_size(8, 0)"),
],
)
@pytest.mark.parametrize(
"num_segments_per_seq",
[pytest.param(5, id="SEG-5")],
)
def test_context_parallel_allgather_striped_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
window_size,
stripe_size,
num_segments_per_seq,
):
if not qkv_layout.is_thd():
pytest.skip("Only THD layout is supported for CP + AG + Striped attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
window_size=window_size,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
......@@ -514,6 +585,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -577,6 +650,8 @@ class TestDistributedContextParallelSelfAttn:
"When context parallelism and sliding window attention are used, "
"scanloop is not supported"
)
# Set the stripe size to 1 (ring attention only support stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -592,6 +667,7 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=False,
use_scan_ring=use_scan,
window_size=window_size,
stripe_size=stripe_size,
)
@pytest_parametrize_wrapper(
......@@ -616,6 +692,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
):
kv_groups = 8
# Set the stripe size to 1 (ring attention only support stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -630,6 +708,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
stripe_size=stripe_size,
)
......@@ -639,31 +718,39 @@ REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}
REORDER_STRATEGY = [
pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
]
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
"reorder_strategy, stripe_size",
REORDER_STRATEGY,
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0
if reorder_strategy == ReorderStrategy.Striped:
seq_lens = shape[seq_dim]
if seq_lens < (cp_size * stripe_size):
pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}")
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)
assert jnp.array_equal(inversed, ref)
......@@ -352,6 +352,8 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
stripe_size: int | None = None
num_segments_per_seq: int | None = None
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
......@@ -366,6 +368,14 @@ class FusedAttnRunner:
# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None
def __post_init__(self):
# Reset defaults for num_segments_per_seq if not explicitly passed
if self.num_segments_per_seq is None:
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
else:
self.num_segments_per_seq = 1
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
......@@ -577,7 +587,6 @@ class FusedAttnRunner:
return segment_ids, segment_pos, segment_pad
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
......@@ -603,7 +612,6 @@ class FusedAttnRunner:
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
self.batch_size, self.max_seqlen_q, pad_ratio
)
......@@ -635,12 +643,14 @@ class FusedAttnRunner:
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
else:
# no-ops for non cp or non load balanced
......@@ -771,7 +781,7 @@ class FusedAttnRunner:
def test_forward(self):
"""
Test forward without JIT
Test forward with JITted primitive and unJITted reference
"""
self._setup_inputs()
......@@ -801,6 +811,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
customcall_fused_dpa_jit = jit(
......@@ -896,6 +907,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
# We can compute dBias only for the [1, h, s, s] layout
......
......@@ -386,23 +386,57 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
def reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
"""Reorders a tensor for load balancing the compute of causal attention."""
if strategy == ReorderStrategy.DualChunkSwap:
if stripe_size is not None:
raise ValueError(
f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
" None"
)
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if stripe_size is not None and stripe_size <= 0:
raise ValueError(
f"Incorrect value for CP striped reordering {stripe_size=}. stripe_size must be a"
" positive integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size = 1 if stripe_size is None else stripe_size
return tex.attention.reorder_causal_striped(
tensor, cp_size, seq_dim, False, effective_stripe_size
)
raise ValueError(f"Unsupported {strategy=}")
def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
"""Inverse operation of `reorder_causal_load_balancing`."""
if strategy == ReorderStrategy.DualChunkSwap:
if stripe_size is not None:
raise ValueError(
f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
" None"
)
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if stripe_size is not None and stripe_size <= 0:
raise ValueError(
f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
" integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size = 1 if stripe_size is None else stripe_size
return tex.attention.reorder_causal_striped(
tensor, cp_size, seq_dim, True, effective_stripe_size
)
raise ValueError(f"Unsupported {strategy=}")
......@@ -988,7 +1022,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
......@@ -1008,6 +1042,7 @@ def _fused_attn(
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
context_checkpoint_name: str = "context",
stripe_size: int | None = None,
):
output, _ = _fused_attn_fwd_rule(
qkv,
......@@ -1028,6 +1063,7 @@ def _fused_attn(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output
......@@ -1051,6 +1087,7 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
......@@ -1070,6 +1107,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
......@@ -1099,6 +1137,7 @@ def _fused_attn_bwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
ctx,
dz,
):
......@@ -1133,6 +1172,7 @@ def _fused_attn_bwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
......@@ -1169,6 +1209,7 @@ def fused_attn(
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
stripe_size: int | None = None,
):
"""
Perform cuDNN fused attention.
......@@ -1206,6 +1247,11 @@ def fused_attn(
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
stripe_size (int | None):
Indicates the striping size to be used when using ReorderStrategy.Striped.
Currently, a stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas a stripe_size=1
is supported for both, CP + THD + Striped + AG and CP + THD + Striped + P2P(Ring)
None indicates no striping strategy
Returns:
(jnp.ndarray): The output tensor from the fused attention.
......@@ -1283,5 +1329,6 @@ def fused_attn(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output
......@@ -73,6 +73,7 @@ __all__ = [
"context_parallel_load_balanced",
"cp_axis",
"cp_striped_window_size",
"stripe_size",
],
)
@dataclass(frozen=True)
......@@ -92,7 +93,10 @@ class _FusedAttnConfig:
window_size: Tuple[int, int]
context_parallel_load_balanced: bool
cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA
stripe_size: (
int | None
) # Only for CP + Striped. For Ring P2P, stripe_size=1 only.For AG, stripe_size>=1.
@dataclass(frozen=True)
......@@ -527,7 +531,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
segment_ids=(_q_segment_ids, _kv_segment_ids),
segment_pos=(_q_segment_pos, _kv_segment_pos),
)
(q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
sequence_descriptor.get_seqlens_and_offsets(
config.attn_mask_type,
......@@ -536,7 +539,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.max_segments_per_seq,
)
)
if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1):
......@@ -1234,31 +1236,38 @@ def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contig
return combined.reshape(ori_tensor_shape)
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
def reorder_causal_striped(
tensor, cp_size: int, seq_dim: int, is_inverse: bool, stripe_size: int = 1
):
"""Reorders a tensor for load balancing with striped pattern"""
origin_shape = tensor.shape
if origin_shape[seq_dim] % cp_size != 0:
if stripe_size <= 0:
raise ValueError(
f"Incorrect value for CP reordering {stripe_size=}. stripe_size must be a positive"
" integer"
)
if origin_shape[seq_dim] % (cp_size * stripe_size) != 0:
raise ValueError(
"Expected origin_shape[seq_dim] is multiple of cp_size but got"
f" {origin_shape[seq_dim]=} and {cp_size=}"
"Expected origin_shape[seq_dim] is multiple of cp_size*stripe_size but got"
f" {origin_shape[seq_dim]=}, {cp_size=}, {stripe_size=}, {cp_size*stripe_size=}"
)
if not is_inverse:
new_shape = [
*origin_shape[:seq_dim],
*[origin_shape[seq_dim] // cp_size, cp_size],
*[origin_shape[seq_dim] // (cp_size * stripe_size), cp_size, stripe_size],
*origin_shape[seq_dim + 1 :],
]
else:
new_shape = [
*origin_shape[:seq_dim],
*[cp_size, origin_shape[seq_dim] // cp_size],
*[cp_size, origin_shape[seq_dim] // (cp_size * stripe_size), stripe_size],
*origin_shape[seq_dim + 1 :],
]
chunked_tensor = tensor.reshape(new_shape)
reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
return reordered_chunked_tensor.reshape(origin_shape)
striped_tensor = tensor.reshape(new_shape)
reordered_striped_tensor = jnp.swapaxes(striped_tensor, seq_dim, seq_dim + 1)
return reordered_striped_tensor.reshape(origin_shape)
@dataclass(frozen=True)
......@@ -1272,26 +1281,47 @@ class _FusedAttnCPWithAllGatherHelper:
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention"
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
allowed_layouts = [
QKVLayout.BSHD_BS2HD,
QKVLayout.BSHD_BSHD_BSHD,
QKVLayout.THD_T2HD,
QKVLayout.THD_THD_THD,
]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if (not self.config.qkv_layout.is_thd() and self.config.stripe_size is not None) or (
self.config.qkv_layout.is_thd() and self.config.stripe_size is None
):
raise ValueError(
f"{header} only supports Dual Chunk load balancing with BSHD layouts and Striped"
" load balancing with THD layouts"
)
if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.qkv_layout.is_thd():
allowed_masks.append(AttnMaskType.PADDING_CAUSAL_MASK)
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
# Do not allow CP + AG + THD + Striped with NO_MASK
if (
self.config.attn_mask_type is not AttnMaskType.PADDING_CAUSAL_MASK
and self.config.qkv_layout.is_thd()
):
raise ValueError(f"{header} only supports PADDING_CAUSAL_MASK for THD types")
if self.config.max_segments_per_seq != 1:
if self.config.max_segments_per_seq != 1 and (not self.config.qkv_layout.is_thd):
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f"{header} only supports max_segments_per_seq == 1 for BSHD layouts, got:"
f" {self.config.max_segments_per_seq}"
)
......@@ -1305,10 +1335,25 @@ class _FusedAttnCPWithAllGatherHelper:
def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
if (
self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK
and not self.config.qkv_layout.is_thd()
): # BSHD AG case only
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
if (
self.config.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.config.qkv_layout.is_thd()
): # THD AG case only
return AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type
def get_adjusted_max_segments_per_seq(self, max_seqlen, cp_size):
"""Converts the max segments per seq for context parallelism AG + THD."""
# Estimating adjusted max segments per seq
return (
max_seqlen // (self.config.stripe_size * cp_size)
) + self.config.max_segments_per_seq
def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig(
......@@ -1324,10 +1369,29 @@ class _FusedAttnCPWithAllGatherHelper:
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
)
def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
)
def all_gather_kv(self, k, v):
"""Performs a all-gather of k and v over context parallel ranks."""
"""Performs an all-gather of k and v over context parallel ranks."""
def ag(x):
x = lax_paral_op(
......@@ -1335,7 +1399,10 @@ class _FusedAttnCPWithAllGatherHelper:
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
if self.config.qkv_layout.is_thd():
x = reorder_causal_striped(x, cp_size, 1, True, self.config.stripe_size)
else:
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x
if self.config.qkv_layout.is_kvpacked():
......@@ -1345,13 +1412,36 @@ class _FusedAttnCPWithAllGatherHelper:
return k, v # fall through
def all_gather_segment_ids_and_pos(self, kv_segment_ids, kv_segment_pos):
"""Performs an all-gather of kv segment ids and kv segment pos over context parallel ranks."""
kv_segment_ids = lax_paral_op(
kv_segment_ids, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
kv_segment_pos = lax_paral_op(
kv_segment_pos, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
if self.config.qkv_layout.is_thd():
kv_segment_ids_ag = reorder_causal_striped(
kv_segment_ids, cp_size, 1, True, self.config.stripe_size
)
kv_segment_pos_ag = reorder_causal_striped(
kv_segment_pos, cp_size, 1, True, self.config.stripe_size
)
return kv_segment_ids_ag, kv_segment_pos_ag
return kv_segment_ids, kv_segment_pos # fall through
def reduce_scatter_dkv(self, dk, dv):
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
def rs(x):
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
if self.config.qkv_layout.is_thd():
x = reorder_causal_striped(x, cp_size, 1, False, self.config.stripe_size)
else:
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
return lax_paral_op(
x,
......@@ -1424,6 +1514,227 @@ class _FusedAttnCPWithAllGatherHelper:
return dk, dv # fall through
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, True, True, True]]
# seqlens_pre = [[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 0, 0, 0, 0]]
# seqlens_all_pad_neg = [[ 4, 4, 4, -1, -1, -1, -1]]
def q_seqlens_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
"""Extract the q seqlens for striped primitive (post AG) from the sharded q seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask = q_segment_ids != 0
max_size = q_segment_ids.shape[-1]
non_zero_indices = jax.vmap(
lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)
# Pick non-zero seg ids and seg pos using take_along_axis to index within the seg ids and pos
# Clip -1 to 0 for safe indexing
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(q_segment_ids, clipped_indices, axis=-1), 0
)
valid_segment_pos = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(q_segment_pos, clipped_indices, axis=-1), 0
)
# Create a mask for actual valid entries (not padding)
actual_valid = valid_segment_ids != 0
# First element is True only if it's actually valid
first_is_segment = actual_valid[..., 0:1]
# Detect segment breaks in the valid tokens only (not full seq)
# Padding will always be true as the segment change condition is being applied
# on the valid segments (which have padding at the end so they'll always trigger True)
segment_changes = jnp.concatenate(
[
first_is_segment, # First valid element starts a segment
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
],
axis=-1,
)
new_segment_ids = jnp.cumsum(segment_changes, axis=-1)
seqlens_pre = jax.vmap(
lambda av_row, nsi_row: jnp.where(av_row, nsi_row, 0).astype(jnp.int32)
)(actual_valid, new_segment_ids)
seqlens_all = jax.vmap(
lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:]
)(seqlens_pre)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
return seqlens_all_pad_neg
# Below are the sharded post AG q seg ids and pos for a given rank:
# q_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# q_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes = [[ True, False, False, False, True, False, False, False, True, False, False, False, True, False, False, False]]
# segment_changes_masked = [[ True, False, False, False, False, False, False, False, True, False, False, False, True, False, False, False]]
# seq_offsets = [[ 0, 8, 12, -1, -1, -1, -1, -1]]
def q_seqoffsets_for_striped_for_rank(self, q_segment_ids, q_segment_pos, max_segments_per_seq):
"""Extract the q seqoffets for striped primitive (post AG) from the sharded q seg ids and seg pos"""
segment_changes = jnp.concatenate(
[
jnp.full(
(q_segment_pos.shape[0], 1), True, dtype=bool
), # First valid element starts a segment
(q_segment_pos[..., 1:] != q_segment_pos[..., :-1] + 1), # Segment pos changed
],
axis=-1,
)
# Remove any padded region segment changes
segment_changes_masked = jnp.where(q_segment_ids != 0, segment_changes, False)
# Get the indices for segment changes (these are the offsets)
seq_offsets = jax.vmap(
lambda scm_row: jnp.where(scm_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_masked)
return seq_offsets
# Below are the sharded post AG q seg ids and pos for a given rank:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# non_zero_mask = [[ True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True]]
# non_zero_indices = [[ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]]
# segment_changes = [[False, False, False, True, False, False, False, True, False, False, False, True, True, True, True, False]]
# selected_values = [[ 4, 15, 31, -1, -1, -1, -1, -1]]
def kv_seqlens_for_striped_for_rank(self, kv_segment_ids, kv_segment_pos, max_segments_per_seq):
"""Extract the kv seqlens for striped primitive (post AG) from the sharded kv seg ids and seg pos"""
# Create mask for non-zero seg ids and get the non-zero indices associated with the same
non_zero_mask = kv_segment_ids != 0
max_size = kv_segment_ids.shape[-1]
non_zero_indices = jax.vmap(
lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)
# Pick non zero seg ids and seg pos using take_along_axis
# Clip -1 to 0 for safe indexing
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(kv_segment_ids, clipped_indices, axis=-1), 0
)
valid_segment_pos = jnp.where(
non_zero_indices >= 0, jnp.take_along_axis(kv_segment_pos, clipped_indices, axis=-1), 0
)
actual_valid = valid_segment_ids != 0
# Detect segment breaks (only for non-zero segments)
segment_changes = jnp.concatenate(
[
(
(valid_segment_ids[..., 1:] != valid_segment_ids[..., :-1])
& actual_valid[..., 1:]
)
| (valid_segment_pos[..., 1:] != valid_segment_pos[..., :-1] + 1),
actual_valid[..., -1:],
],
axis=-1,
)
# Get the indices for segment changes
segment_changes_valid = jax.vmap(
lambda sc_row, av_row: jnp.where(
sc_row & av_row, size=max_segments_per_seq, fill_value=-1
)[0]
)(segment_changes, actual_valid)
safe_indices = jnp.maximum(segment_changes_valid, 0)
# Select values using take_along_axis per row
selected_values = jnp.where(
segment_changes_valid >= 0,
jnp.take_along_axis(valid_segment_pos, safe_indices, axis=-1) + 1,
-1,
)
return selected_values
# Below are the sharded post AG q seg ids and pos for a given rank:
# kv_segment_ids = [[1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2]]
# kv_segment_pos = [[0, 1, 2, 3, 16, 17, 18, 19, 11, 12, 13, 14, 27, 28, 29, 30]]
# kv_segment_ids_ag = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
# 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# kv_segment_pos_ag = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
# 18, 19, 20, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
# 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
# max_segments_per_seq = 7
# Below are some intermediate representations:
# segment_changes_first_true_masked = [[ True, False, False, False, False, False, False, False, True,
# False, False, False, True, False, False, False]]
# segment_changes_indices = [[ 0, 8, 12, -1, -1, -1, -1, -1, -1]]
# segment_ids = [[ 1, 2, 2, -1, -1, -1, -1, -1, -1]]
# segment_changes_ag_first_true_masked = [[ True, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, True, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False, False, False, False, False, False, False, False, False,
# False]
# segment_changes_ag_indices = [[ 0, 21, -1, -1, -1, -1, -1, -1, -1]]
# seq_offsets = [[ 0, 21, 21, -1, -1, -1, -1, -1, -1]]
def kv_seqoffsets_for_striped_for_rank(
self,
kv_segment_pos,
kv_segment_ids,
kv_segment_pos_ag,
kv_segment_ids_ag,
max_segments_per_seq,
):
"""Extract the kv seqoffsets for striped primitive (post AG) from the sharded kv seg ids and seg pos,
AG kv seg ids and seg pos."""
# Calculate the segment pos change mask
segment_changes_first_true = jnp.concatenate(
[
jnp.full(
(kv_segment_pos.shape[0], 1), True, dtype=bool
), # Assume valid element starts a segment and mask afterwards
(kv_segment_pos[..., 1:] != kv_segment_pos[..., :-1] + 1), # Segment pos changed
],
axis=-1,
)
segment_changes_first_true_masked = jnp.where(
kv_segment_ids != 0, segment_changes_first_true, False
)
# Get segment change indices for rank
segment_changes_indices = jax.vmap(
lambda sc_row: jnp.where(sc_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_first_true_masked)
# Get segment ids associated with the segment_changes_indices for rank
segment_ids = jax.vmap(
lambda sci_row, ksi_row: jnp.where(sci_row >= 0, ksi_row[sci_row], -1)
)(segment_changes_indices, kv_segment_ids)
# Get segment change indices for AG
segment_changes_ag_first_true = jnp.concatenate(
[
jnp.full(
(kv_segment_pos.shape[0], 1), True, dtype=bool
), # Assume valid element starts a segment and mask afterwards
(
kv_segment_pos_ag[..., 1:] != kv_segment_pos_ag[..., :-1] + 1
), # Segment pos changed
],
axis=-1,
)
segment_changes_ag_first_true_masked = jnp.where(
kv_segment_ids_ag != 0, segment_changes_ag_first_true, False
)
# Get segment change indices for AG
segment_changes_ag_indices = jax.vmap(
lambda scag_row: jnp.where(scag_row, size=max_segments_per_seq, fill_value=-1)[0]
)(segment_changes_ag_first_true_masked)
# Use the segment ids picked per rank to get the offsets from the AG indices
seq_offsets = jax.vmap(
lambda si_row, sca_row: jnp.where(si_row > 0, sca_row[si_row - 1], -1)
)(segment_ids, segment_changes_ag_indices)
return seq_offsets
class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
"""
......@@ -1501,7 +1812,6 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
q_seqlen_for_step = q_seqlen / (cp_size * 2)
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen / (cp_size * 2)) * num_kv_chunks
output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
q_split[sub_idx],
k_unmasked,
......@@ -1722,6 +2032,314 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
class FusedAttnCPStripedWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Attention Forward with Context Parallelism and Striped Load Balancing Primitive
This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[5] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(
q,
k,
v,
bias,
softmax_offset,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
): # pylint: disable=unused-argument
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# cuDNN does not support right-aligned masking with dynamic sequence length padding.
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
# to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
# Each rank receives the ag k and v along with the ag kv seg ids and kv seg offsets
# Each rank sees the sharded view for 5 tensors -> q, _q_segment_ids, _q_segment_pos,
# _kv_segment_ids, _kv_segment_pos -> Note these have also been reordered before passing in.
def _cross_attn(
q, k, v, bias, softmax_offset, kv_segment_ids_ag, kv_segment_pos_ag, seed
):
# Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive
# Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos()
# does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it
kv_max_seqlen = k.shape[1]
# Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq
adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
max_seqlen=kv_max_seqlen, cp_size=cp_size
)
q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
)
q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
q_segment_ids=_q_segment_ids,
q_segment_pos=_q_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
kv_segment_ids=_kv_segment_ids,
kv_segment_pos=_kv_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
kv_segment_pos=_kv_segment_pos,
kv_segment_ids=_kv_segment_ids,
kv_segment_pos_ag=kv_segment_pos_ag,
kv_segment_ids_ag=kv_segment_ids_ag,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
q, # sharded for rank
k, # ag
v, # ag
bias,
softmax_offset,
seed,
q_seqlens_for_rank,
kv_seqlens_for_rank,
q_seq_offsets_for_rank,
kv_seq_offsets_for_rank,
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
config=helper.get_step_config_for_striped(
max_seqlen=kv_max_seqlen, cp_size=cp_size
),
)
return output, softmax_aux, rng_state
# AG the k, v, kv_segment_ids and kv_segment_pos
k_ag, v_ag = helper.all_gather_kv(k, v)
_kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
_kv_segment_ids, _kv_segment_pos
)
functions = [
partial(
_cross_attn,
q,
k_ag,
v_ag,
bias,
softmax_offset,
_kv_segment_ids_ag,
_kv_segment_pos_ag,
seed,
)
for _ in range(cp_size)
]
return lax.switch(cp_rank, functions)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPStripedWithAllGatherFwdPrimitive)
class FusedAttnCPStripedWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Attention Backward with Context Parallelism and Striped Load Balancing Primitive.
This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
The gradients are subsequently reduce-scattered back to each context parallel rank.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
# Ensure we can support this configuration with context parallelism.
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
softmax_offset_spec = get_padded_spec(arg_infos[4])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
dsoftmax_offset_sharding = NamedSharding(mesh, PartitionSpec(*softmax_offset_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
dq_sharding,
dk_sharding,
dv_sharding,
dbias_sharding,
dsoftmax_offset_sharding,
)
def impl(
q,
k,
v,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
): # pylint: disable=unused-argument
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# See comment in FusedAttnCPFwdPrimitive.partition for why we define this function.
def _cross_attn_bwd(
q,
k,
v,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
_q_segment_ids,
kv_segment_ids_ag,
_q_segment_pos,
kv_segment_pos_ag,
):
# Helper generates the seqlens and offsets for q and kv and then pass them down to the FusedAttnFwdPrimitive
# Unset the segment_ids and segment_pos by passing placeholders so that the seqlens_from_segment_ids_pos()
# does not go down that route but instead just picks the pre-computed seqlens and offsets passed onto it
kv_max_seqlen = k.shape[1]
# Estimate an adjusted max_segments_per_seq per rank based on the global max_segments_per_seq
adjusted_max_segments_per_seq = helper.get_adjusted_max_segments_per_seq(
max_seqlen=kv_max_seqlen, cp_size=cp_size
)
q_seqlens_for_rank = helper.q_seqlens_for_striped_for_rank(
_q_segment_ids, _q_segment_pos, adjusted_max_segments_per_seq
)
q_seq_offsets_for_rank = helper.q_seqoffsets_for_striped_for_rank(
q_segment_ids=_q_segment_ids,
q_segment_pos=_q_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seqlens_for_rank = helper.kv_seqlens_for_striped_for_rank(
kv_segment_ids=_kv_segment_ids,
kv_segment_pos=_kv_segment_pos,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
kv_seq_offsets_for_rank = helper.kv_seqoffsets_for_striped_for_rank(
kv_segment_pos=_kv_segment_pos,
kv_segment_ids=_kv_segment_ids,
kv_segment_pos_ag=kv_segment_pos_ag,
kv_segment_ids_ag=kv_segment_ids_ag,
max_segments_per_seq=adjusted_max_segments_per_seq,
)
dq_local, dk_local, dv_local, dbias_local, _ = FusedAttnBwdPrimitive.impl(
q, # sharded for rank
k, # ag
v, # ag
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
q_seqlens_for_rank,
kv_seqlens_for_rank,
q_seq_offsets_for_rank,
kv_seq_offsets_for_rank,
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
jnp.zeros(0),
config=helper.get_step_config_for_striped(
max_seqlen=kv_max_seqlen, cp_size=cp_size
),
)
return dq_local, dk_local, dv_local, dbias_local
# AG the k, v, kv_segment_ids and kv_segment_pos
k_ag, v_ag = helper.all_gather_kv(k, v)
_kv_segment_ids_ag, _kv_segment_pos_ag = helper.all_gather_segment_ids_and_pos(
_kv_segment_ids, _kv_segment_pos
)
functions = [
partial(
_cross_attn_bwd,
q,
k_ag,
v_ag,
bias,
softmax_offset,
softmax_aux,
rng_state,
output,
doutput,
_q_segment_ids,
_kv_segment_ids_ag,
_q_segment_pos,
_kv_segment_pos_ag,
)
for _ in range(cp_size)
]
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
# RS the dk and dv
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
# Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset = jnp.empty_like(softmax_offset)
return dq, dk, dv, dbias, dummy_dsoftmax_offset
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPStripedWithAllGatherBwdPrimitive)
@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
"""Helper class to assist with running the P2P ring strategy for CP attention."""
......@@ -1811,6 +2429,7 @@ class _FusedAttnCPWithP2PHelper:
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
stripe_size=self.config.stripe_size,
)
def stack_kv(self, k, v):
......@@ -2693,6 +3312,7 @@ def fused_attn_fwd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
stripe_size: int | None = None,
) -> jnp.ndarray:
"""
Perform the forward pass of with cuDNN fused attention implementations.
......@@ -2731,6 +3351,7 @@ def fused_attn_fwd(
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
......@@ -2796,12 +3417,16 @@ def fused_attn_fwd(
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
stripe_size=stripe_size,
)
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
if qkv_layout.is_thd():
primitive = FusedAttnCPStripedWithAllGatherFwdPrimitive.outer_primitive
else:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING:
# We must use stripe attention for THD-RING
if qkv_layout.is_thd():
......@@ -2843,6 +3468,7 @@ def fused_attn_bwd(
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
stripe_size: int | None = None,
):
"""
Perform the backward pass of the cuDNN fused attention implementations.
......@@ -2882,6 +3508,7 @@ def fused_attn_bwd(
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
stripe_size (int | None): Indicates the striping height to be used for ReorderStrategy.Striped Load Balancing
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
......@@ -2954,12 +3581,16 @@ def fused_attn_bwd(
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
stripe_size=stripe_size,
)
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
if qkv_layout.is_thd():
primitive = FusedAttnCPStripedWithAllGatherBwdPrimitive.outer_primitive
else:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING:
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
......
......@@ -176,6 +176,9 @@ _primitive_registry = {}
def register_primitive(cls, outer_only=False):
"""
Register a JAX primitive and add it to the internal registry.
Inner primitive - single device, no sharding awareness, eager mode fallback
Outer primitive - multi device, sharding aware, partition() distributes work,
used when there's a dev mesh context
"""
_primitive_registry[cls.__name__] = cls
......@@ -190,14 +193,17 @@ def register_primitive(cls, outer_only=False):
inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.multiple_results = cls.multiple_results
# Define eager execution implementation (by invoking it's MLIR lowering)
inner_p.def_impl(partial(xla.apply_primitive, inner_p))
inner_p.def_abstract_eval(cls.abstract)
mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
cls.inner_primitive = inner_p
# Create the outer primitive for distributed execution
outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
# Define the eager execution implementation
outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment