"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "660a7fcfa40d62305ecba6bc6352c4026d56d680"
Unverified Commit fd0cd12e authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add CP + THD + AG + Striped>1 + SWA support (#2379)
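For orientation before the commit-by-commit history: the round-robin striped reordering that this PR generalizes to stripe_size > 1 can be sketched in plain NumPy. This is an illustrative sketch only; the real implementation lives in `tex.attention.reorder_causal_striped`, and the function below is a hypothetical stand-in.

```python
import numpy as np

def striped_reorder(x, cp_size, seq_dim=0, inverse=False, stripe_size=1):
    """Deal fixed-size stripes of the sequence round-robin across CP ranks.

    Each rank then ends up with an even mix of early (cheap) and late
    (expensive) causal-attention positions, which balances the load.
    """
    seqlen = x.shape[seq_dim]
    assert seqlen % (cp_size * stripe_size) == 0
    n_stripes = seqlen // stripe_size
    # Forward permutation of stripe indices: rank r receives stripes
    # r, r + cp_size, r + 2*cp_size, ... laid out contiguously.
    perm = np.arange(n_stripes).reshape(-1, cp_size).T.reshape(-1)
    if inverse:
        perm = np.argsort(perm)  # invert the forward permutation
    stripes = np.split(x, n_stripes, axis=seq_dim)
    return np.concatenate([stripes[i] for i in perm], axis=seq_dim)
```

With `cp_size=2, stripe_size=2` on positions `0..7`, the forward reorder yields `[0, 1, 4, 5, 2, 3, 6, 7]`, and applying the inverse restores the original order.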



* Add generic stripe_height support for load balancing
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix imports in test for deprecated jax.experimental.pjit
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add test case for stripe_height greater than 1. Add stripe_height arg to reordering methods
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add Striped-1 and Striped-4 test cases. Refactor the load balancing test case. Fix the incorrect shape in striped inverse reordering
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Modify test code for CP + AG + THD + stripe height greater than 1
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add stripe_height arg to the fused attn and fused attn fwd APIs. Add appropriate mask checks for AG+THD+CP and pick the BRCM to be executed per rank. Add a fused attn primitive for CP + THD + AG + striping. Add a method to reorder and all-gather segment ids and offsets for kv
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway testing commit
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add comments on the primitive registration process
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway test commit
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Undo incorrect rebase/merge leftovers
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* TMP: Throwaway test commits
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Add support for calculating q and kv seqlens and offsets per rank for the CP+THD+AG+SWA+Striped>1 primitive
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Augment jax primitive registration code comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* Fix the array sizes and padding values returned for seqlens and offsets to fit what the fused attn primitive's non-CP computation expects
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add support in the new primitive for softmax_offset related changes. Put the missing primitive registration line back in. Increase the seqoffsets array lengths by 1
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Add a new set of helper functions for seqlens and seqoffsets for AG+THD+CP+Stripe>1 which accounts for batching and a seq offsets size of b+1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add backward primitive for CP+THD+AG+Striped>1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Modify tests for backward primitive for CP+THD+AG+Striped>1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Move stripe_height along with the other static args in the fused_attn_bwd rule. Fix typo in the CP+AG+THD+Striped>1 primitive
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Code clean-up: remove the older version for calculating seqlens and offsets for the CP+AG+THD+Striped>1 primitive
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add test for CP+THD+AG+Striped>1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix missing var
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add SWA tests for CP+THD+AG+Striped>1
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Restore test code
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove assert preventing the SWA code path in the CP+AG+Striped primitive
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Parametrize num_segments_per_seq in tests
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Clean up test code
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Clean up test code in TE common
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Clean up debug statements
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Rename stripe_height to stripe_size
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Code clean-up and add additional comments
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

nit: Apply suggestions from code review
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

Fix typo in fused attn tests
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Fix the seqoffsets length passed to the FusedAttn primitive, as it is b and not the b+1 needed by cuDNN
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove commented code
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

Fix linting issues
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Fix incorrect greptile change
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip THD test cases for CP + AG + dual chunk. Skip BSHD cases for CP + AG + Striped>1. Correct the layout and shape parameters passed to the tests
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Pass stripe_size explicitly for ring attn tests for THD cases
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Remove TODO
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* Explicitly fail if THD + AG is used with a non padding-causal mask
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* nit: Correct the IDs for the distributed fused attn tests to account for the cp*2 scaling which is done under the hood
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Set the num_segments_per_seq default to None instead of 0
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Augment comments. Add ValueError for stripe_size=0
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Test only 1 num_segments_per_seq combination for CP+AG+THD+Striped>1+SWA instead of 2. Change the num segments and window size to easier-to-debug values
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Default stripe_size to None instead of 0. Change the stripe_size check to <=0 instead of ==0
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove incorrectly added file
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Explicitly pass zero-sized arrays for seg ids and pos in the CP + AG + Striped primitive rather than using the seqlens or the offsets as placeholders
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Fix linting errors
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add a deep-dive doc for CP+THD+AG+Stripe>1+SWA covering design considerations and decisions
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Put docs and pngs into their own separate dir
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Replace png screenshots with markdown code blocks for the attention patterns. Remove unnecessary pngs
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add doc file to index.rst. Fix grammatical errors
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
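To make the load-balancing trade-off behind the DualChunkSwap and Striped reorder strategies concrete, here is a small back-of-the-envelope count (illustrative helper names, not TE code) of how many key positions each CP rank's queries attend to under a causal mask:

```python
import numpy as np

def causal_workload_per_rank(token_order, cp_size):
    """token_order: original causal positions in their post-reorder layout.

    Rank r owns the r-th contiguous chunk; a query at original position p
    attends to p + 1 keys, so summing p + 1 gives a rough per-rank cost.
    """
    chunks = np.array_split(np.asarray(token_order), cp_size)
    return [int((chunk + 1).sum()) for chunk in chunks]

seq, cp = np.arange(8), 2
naive = causal_workload_per_rank(seq, cp)                         # [10, 26]
# DualChunkSwap: rank 0 takes chunks 0 and 3, rank 1 takes chunks 1 and 2.
dual = causal_workload_per_rank([0, 1, 6, 7, 2, 3, 4, 5], cp)     # [18, 18]
# Striped (stripe_size=1): ranks get interleaved positions.
striped = causal_workload_per_rank([0, 2, 4, 6, 1, 3, 5, 7], cp)  # [16, 20]
```

Both reorders even out the causal cost relative to the naive contiguous split; striping with a larger stripe_size keeps memory access more contiguous while staying close to balanced.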
parent f0572aa5
@@ -56,3 +56,4 @@ Transformer Engine documentation
api/c/index
debug
examples/attention/attention.ipynb
examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb
@@ -327,9 +327,9 @@ DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
]
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
# Sequence lengths will be scaled by CP*2 so that we don't run with tiny sizes.
pytest.param([2, 128, 8, 128], id="2-128xCPx2-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCPx2-16-64"),
]
@@ -351,12 +351,14 @@ class TestDistributedContextParallelSelfAttn:
use_shardy,
use_scan_ring=False,
window_size=None,
stripe_size=None,
num_segments_per_seq=None,
):
if qkv_layout.is_thd():
if cp_strategy == CPStrategy.ALL_GATHER:
pytest.skip("THD doesn't support all gather context parallelism.")
if not load_balanced and cp_strategy == CPStrategy.RING:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
if not load_balanced and (
cp_strategy == CPStrategy.RING or cp_strategy == CPStrategy.ALL_GATHER
):
pytest.skip(f"THD + {cp_strategy=} doesn't support unbalanced context parallelism.")
assert not use_scan_ring or cp_strategy == CPStrategy.RING
@@ -382,7 +384,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
runner = FusedAttnRunner(
batch,
seqlen,
@@ -401,6 +402,8 @@ class TestDistributedContextParallelSelfAttn:
bias_shape,
window_size,
SeqDescFormat.SegmentIDs,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -453,7 +456,7 @@ class TestDistributedContextParallelSelfAttn:
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
@@ -470,6 +473,8 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
kv_groups = 8
self.impl_test_context_parallel_attn(
device_count,
@@ -486,6 +491,72 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=True,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
)
@pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"qkv_layout, attn_mask_type",
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(True, id="BALANCED")],
)
@pytest.mark.parametrize(
"stripe_size",
[pytest.param(64, id="STRIPE-64"), pytest.param(128, id="STRIPE-128")],
)
@pytest.mark.parametrize(
"window_size",
[
pytest.param((-1, -1), id="window_size(-1, -1)"),
pytest.param((5, 0), id="window_size(5, 0)"),
],
)
@pytest.mark.parametrize(
"num_segments_per_seq",
[pytest.param(5, id="SEG-5")],
)
def test_context_parallel_allgather_striped_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
window_size,
stripe_size,
num_segments_per_seq,
):
if not qkv_layout.is_thd():
pytest.skip("Only THD layout is supported for CP + AG + Striped attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
use_shardy=False,
window_size=window_size,
stripe_size=stripe_size,
num_segments_per_seq=num_segments_per_seq,
)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_context_parallel_configs_for_attn(),
@@ -514,6 +585,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention")
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
@@ -577,6 +650,8 @@ class TestDistributedContextParallelSelfAttn:
"When context parallelism and sliding window attention are used, "
"scanloop is not supported"
)
# Set the stripe size to 1 (ring attention only supports stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
@@ -592,6 +667,7 @@ class TestDistributedContextParallelSelfAttn:
use_shardy=False,
use_scan_ring=use_scan,
window_size=window_size,
stripe_size=stripe_size,
)
@pytest_parametrize_wrapper(
@@ -616,6 +692,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
):
kv_groups = 8
# Set the stripe size to 1 (ring attention only supports stripe_size=1)
stripe_size = 1 if qkv_layout.is_thd() else None
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
@@ -630,6 +708,7 @@ class TestDistributedContextParallelSelfAttn:
cp_strategy=CPStrategy.RING,
use_shardy=False,
use_scan_ring=True,
stripe_size=stripe_size,
)
@@ -639,31 +718,39 @@ REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = {
"L2": [[4, 32, 12, 32], [1, 16, 1, 1]],
}
REORDER_STRATEGY = [
pytest.param(ReorderStrategy.DualChunkSwap, None, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, 1, id="Striped-1"),
pytest.param(ReorderStrategy.Striped, 4, id="Striped-4"),
]
class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
@pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD, QKVFormat.THD])
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
"reorder_strategy, stripe_size",
REORDER_STRATEGY,
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
def test(self, cp_size, shape, qkv_format, reorder_strategy, stripe_size):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0
if reorder_strategy == ReorderStrategy.Striped:
seq_lens = shape[seq_dim]
if seq_lens < (cp_size * stripe_size):
pytest.skip(f"{seq_lens=} must be larger than {cp_size*stripe_size=}")
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3, 4])
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim, stripe_size)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim, stripe_size)
assert jnp.array_equal(inversed, ref)
@@ -352,6 +352,8 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
stripe_size: int | None = None
num_segments_per_seq: int | None = None
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
@@ -366,6 +368,14 @@ class FusedAttnRunner:
# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None
def __post_init__(self):
# Reset defaults for num_segments_per_seq if not explicitly passed
if self.num_segments_per_seq is None:
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
else:
self.num_segments_per_seq = 1
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
@@ -577,7 +587,6 @@ class FusedAttnRunner:
return segment_ids, segment_pos, segment_pad
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
@@ -603,7 +612,6 @@ class FusedAttnRunner:
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
self.batch_size, self.max_seqlen_q, pad_ratio
)
@@ -635,12 +643,14 @@ class FusedAttnRunner:
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
else:
# no-ops for non cp or non load balanced
@@ -771,7 +781,7 @@ class FusedAttnRunner:
def test_forward(self):
"""
Test forward without JIT
Test forward with JITted primitive and unJITted reference
"""
self._setup_inputs()
@@ -801,6 +811,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
customcall_fused_dpa_jit = jit(
@@ -896,6 +907,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
# We can compute dBias only for the [1, h, s, s] layout
......
@@ -386,23 +386,57 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
def reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
"""Reorders a tensor for load balancing the compute of causal attention."""
if strategy == ReorderStrategy.DualChunkSwap:
if stripe_size is not None:
raise ValueError(
f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
" None"
)
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if stripe_size is not None and stripe_size <= 0:
raise ValueError(
f"Incorrect value for CP striped reordering {stripe_size=}. stripe_size must be a"
" positive integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size = 1 if stripe_size is None else stripe_size
return tex.attention.reorder_causal_striped(
tensor, cp_size, seq_dim, False, effective_stripe_size
)
raise ValueError(f"Unsupported {strategy=}")
def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int, stripe_size: int | None = None
):
"""Inverse operation of `reorder_causal_load_balancing`."""
if strategy == ReorderStrategy.DualChunkSwap:
if stripe_size is not None:
raise ValueError(
f"Incorrect value for CP dual chunk reordering {stripe_size=}. stripe_size must be"
" None"
)
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
# stripe_size > 1 is only supported for CP+THD+AG+Striped>1+SWA
# stripe_size = 128 is recommended for CP+THD+AG+Striped>1+SWA
if stripe_size is not None and stripe_size <= 0:
raise ValueError(
f"Incorrect value for CP striped reordering {stripe_size=}. stripe_size must be a"
" positive integer"
)
# Supporting old API defaults of stripe_size=1
effective_stripe_size = 1 if stripe_size is None else stripe_size
return tex.attention.reorder_causal_striped(
tensor, cp_size, seq_dim, True, effective_stripe_size
)
raise ValueError(f"Unsupported {strategy=}")
@@ -988,7 +1022,7 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
@@ -1008,6 +1042,7 @@ def _fused_attn(
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
context_checkpoint_name: str = "context",
stripe_size: int | None = None,
):
output, _ = _fused_attn_fwd_rule(
qkv,
@@ -1028,6 +1063,7 @@ def _fused_attn(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output
@@ -1051,6 +1087,7 @@ def _fused_attn_fwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
):
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
@@ -1070,6 +1107,7 @@ def _fused_attn_fwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
output = checkpoint_name(output, context_checkpoint_name)
softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name)
@@ -1099,6 +1137,7 @@ def _fused_attn_bwd_rule(
context_parallel_causal_load_balanced,
context_parallel_axis,
context_checkpoint_name,
stripe_size,
ctx,
dz,
):
@@ -1133,6 +1172,7 @@ def _fused_attn_bwd_rule(
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
stripe_size=stripe_size,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
@@ -1169,6 +1209,7 @@ def fused_attn(
context_parallel_axis: str = "",
context_checkpoint_name: str = "context",
softmax_offset: Optional[jnp.ndarray] = None,
stripe_size: int | None = None,
):
"""
Perform cuDNN fused attention.
@@ -1206,6 +1247,11 @@ def fused_attn(
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
stripe_size (int | None):
Indicates the stripe size to be used when using ReorderStrategy.Striped.
Currently, stripe_size > 1 is only supported for CP + THD + Striped + AG, whereas stripe_size=1
is supported for both CP + THD + Striped + AG and CP + THD + Striped + P2P (Ring).
None indicates no striping strategy.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
@@ -1283,5 +1329,6 @@ def fused_attn(
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
context_checkpoint_name=context_checkpoint_name,
stripe_size=stripe_size,
)
return output
@@ -176,6 +176,9 @@ _primitive_registry = {}
def register_primitive(cls, outer_only=False):
"""
Register a JAX primitive and add it to the internal registry.
Inner primitive - single device, no sharding awareness, eager-mode fallback.
Outer primitive - multi-device, sharding aware; partition() distributes the work
when there is a device mesh context.
"""
_primitive_registry[cls.__name__] = cls
@@ -190,14 +193,17 @@ def register_primitive(cls, outer_only=False):
inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.multiple_results = cls.multiple_results
# Define the eager execution implementation (by invoking its MLIR lowering)
inner_p.def_impl(partial(xla.apply_primitive, inner_p))
inner_p.def_abstract_eval(cls.abstract)
mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
cls.inner_primitive = inner_p
# Create the outer primitive for distributed execution
outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
# Define the eager execution implementation
outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher
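As a minimal illustration of the inner/outer primitive pattern documented in the hunk above (a hedged sketch with hypothetical names; the real `register_primitive` also wires up abstract evaluation, MLIR lowering, batching, and sharding rules):

```python
import numpy as np

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# An "inner"-style primitive: def_impl provides the eager-mode fallback,
# analogous to inner_p.def_impl(...) in register_primitive.
my_scale_p = Primitive("my_scale")
my_scale_p.def_impl(lambda x, *, factor: x * factor)

def my_scale(x, factor=2.0):
    # bind() dispatches to def_impl when called eagerly; to run under jit,
    # you would also register an abstract eval and an MLIR lowering,
    # exactly as register_primitive does for the inner primitive.
    return my_scale_p.bind(x, factor=factor)
```

The outer primitive in `register_primitive` follows the same shape but adds sharding-aware partitioning so the work is distributed when a device mesh is active.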
......