Unverified Commit c5d6a069 authored by Reese Wang, committed by GitHub

[JAX] THD ring attention (#1454)



* Support THD + ring attention for self attn
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Consolidate reorder strategy
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix dataclass frozen issue
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove redundant code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine P2P helper check_supported
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add segment_ids/pos check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fixup
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add dual chunk swap example
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Align different reorder code structure
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4b523d29
......@@ -23,6 +23,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
)
......@@ -210,29 +211,29 @@ class TestDistributedCrossAttn:
"data_shape",
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"attn_mask_type",
"qkv_layout, attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="BSHD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
class TestDistributedContextParallelSelfAttn:
......@@ -265,7 +266,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
runner = FusedAttnRunner(
batch,
......@@ -282,7 +282,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
SeqDescFormat.SegmentIDs,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -297,7 +297,7 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
......@@ -340,6 +340,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -377,7 +379,10 @@ class TestDistributedContextParallelSelfAttn:
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
self.impl_test_context_parallel_attn(
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing:
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
assert jnp.array_equal(inversed, ref)
......@@ -28,12 +28,14 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
......@@ -347,9 +349,9 @@ class FusedAttnRunner:
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
......@@ -500,7 +502,8 @@ class FusedAttnRunner:
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
# TODO(rewang): only self attention is exercised for now; find the reason cross attention fails
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
......@@ -536,6 +539,30 @@ class FusedAttnRunner:
self.window_size,
)
if self.cp_size > 1 and self.cp_load_balanced:
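# THD (ragged) inputs use the Striped reorder; DualChunkSwap assumes a single contiguous
# segment whose length is a multiple of 2 * cp_size, which THD segments need not satisfy.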
if self.qkv_layout.is_thd():
reorder_strategy = ReorderStrategy.Striped
else:
reorder_strategy = ReorderStrategy.DualChunkSwap
seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
else:
# no-ops for non-CP or non-load-balanced runs
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
# Test different input formats
if self.qkv_layout.is_thd():
match self.seq_desc_format:
......@@ -548,8 +575,14 @@ class FusedAttnRunner:
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
......@@ -605,7 +638,12 @@ class FusedAttnRunner:
case _:
def to_dp_shardings(x):
if x.ndim == 1:
pspec = PartitionSpec(self.mesh_resource.dp_resource)
else:
pspec = PartitionSpec(
self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
)
return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
......@@ -637,24 +675,6 @@ class FusedAttnRunner:
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
# Softmax aux sharding
if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
def test_forward(self):
"""
Test forward without JIT
......@@ -733,14 +753,23 @@ class FusedAttnRunner:
self._setup_inputs()
def grad_func(func, *args, **kwargs):
def grad_func(func, *args, cp_reverse_out=False, **kwargs):
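# cp_reverse_out undoes the CP load-balancing reorder on the output so that it lines
# up with the unreordered pad_q mask below.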
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if self.attn_mask_type.is_causal():
gradient_multiplier /= 10
# Keep only valid result for the gradient
if not cp_reverse_out:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
func(*args, **kwargs),
)
else:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
self.cp_inverse_reorder_fn(func(*args, **kwargs)),
)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
......@@ -787,7 +816,7 @@ class FusedAttnRunner:
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
),
arg_nums,
),
......
......@@ -135,6 +135,39 @@ class QKVLayout(Enum):
"""
return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
def to_qkvpacked(self):
"""
Return the corresponding qkvpacked layout; useful when adjusting the q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BS3HD
if qkv_format == QKVFormat.THD:
return QKVLayout.T3HD
raise ValueError(f"Unsupported {qkv_format=}")
def to_kvpacked(self):
"""
Return the corresponding kvpacked layout; useful when adjusting the q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BSHD_BS2HD
if qkv_format == QKVFormat.THD:
return QKVLayout.THD_T2HD
raise ValueError(f"Unsupported {qkv_format=}")
def to_separate(self):
"""
Return the corresponding separate layout; useful when adjusting the q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BSHD_BSHD_BSHD
if qkv_format == QKVFormat.THD:
return QKVLayout.THD_THD_THD
raise ValueError(f"Unsupported {qkv_format=}")
class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
......@@ -149,6 +182,28 @@ class CPStrategy(Enum):
RING = 2
class ReorderStrategy(Enum):
"""
Defines the token reorder strategy for context parallel load balancing with a causal mask.
- DualChunkSwap: This strategy splits each device's sequence chunk in two and mirror-swaps
the halves between GPUs. This is currently used for non-THD load balancing. It requires
max_seqlen to be a multiple of 2 * cp_size.
Example: Consider 4 GPUs with seqlen=16.
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]
- After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
- Striped: This strategy distributes the tokens in a striped (interleaved) manner across
the sequence. This is currently used for THD load balancing.
Example: Consider 4 GPUs with seqlen=16.
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
- After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
"""
DualChunkSwap = 0
Striped = 1
def make_swa_mask(
segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray,
......@@ -243,9 +298,9 @@ def is_fused_attn_kernel_available(
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -276,16 +331,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
"""Reorders a tensor for load balancing the compute of causal attention."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
raise ValueError(f"Unsupported {strategy=}")
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
):
"""Inverse operation of `reorder_causal_load_balancing`."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
raise ValueError(f"Unsupported {strategy=}")
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
......@@ -412,8 +475,6 @@ class SequenceDescriptor:
"""
Acquire the seqlens/offsets for cuDNN backend
"""
attn_mask_type = AttnMaskType(attn_mask_type)
qkv_layout = QKVLayout(qkv_layout)
q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos
assert q_segment_ids.shape == q_segment_pos.shape
......@@ -589,9 +650,9 @@ def _legacy_fused_attn(
Intra-sequence padding is not valid; padded tokens may only appear at the right-most
end of each sequence. Otherwise the results will be wrong.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......@@ -608,16 +669,18 @@ def _legacy_fused_attn(
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
case QKVLayout.BS3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
case QKVLayout.BSHD_BS2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
case QKVLayout.BSHD_BSHD_BSHD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
# convert the mask to seqlens, mask doesn't support ragged offsets
if not attn_mask_type.is_padding():
......@@ -689,16 +752,18 @@ def fused_attn_thd(
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
case QKVLayout.T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_T2HD:
case QKVLayout.THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
case QKVLayout.THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
assert q_seq_lens.shape == (batch, q_max_seqlen)
......@@ -789,9 +854,9 @@ def _fused_attn_fwd_rule(
bias,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......@@ -845,9 +910,9 @@ def _fused_attn_bwd_rule(
output,
dz,
sequence_descriptor,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......@@ -903,9 +968,9 @@ def fused_attn(
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial, reduce
import operator
import os
......@@ -17,17 +17,18 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
CPStrategy,
SequenceDescriptor,
)
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_QKV_Layout,
NVTE_QKV_Format,
NVTE_Fused_Attn_Backend,
nvte_get_qkv_format,
)
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
......@@ -79,9 +80,9 @@ class _FusedAttnConfig:
Passes static configuration properties of fused attention.
"""
attn_bias_type: NVTE_Bias_Type
attn_mask_type: NVTE_Mask_Type
qkv_layout: NVTE_QKV_Layout
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
qkv_layout: QKVLayout
scaling_factor: float
dropout_probability: float
is_training: bool
......@@ -99,9 +100,9 @@ class FusedAttnHelper:
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: NVTE_QKV_Layout
attn_bias_type: NVTE_Bias_Type
attn_mask_type: NVTE_Mask_Type
qkv_layout: QKVLayout
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
dropout_probability: float
q_num_heads: int
kv_num_heads: int
......@@ -119,9 +120,9 @@ class FusedAttnHelper:
return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.dropout_probability,
self.q_num_heads,
self.kv_num_heads,
......@@ -140,23 +141,24 @@ class FusedAttnHelper:
@staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
raise NotImplementedError
if qkv_layout.is_qkvpacked():
*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
kv_head_dim = q_head_dim
assert nqkv == 3
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
elif qkv_layout.is_kvpacked():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
assert nkv == 2
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
elif qkv_layout.is_separate():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert k_aval.shape == v_aval.shape
case _:
assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
else:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
......@@ -310,7 +312,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -330,9 +332,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
config.max_segments_per_seq,
......@@ -385,7 +387,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -419,9 +421,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
......@@ -511,7 +513,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
)
)
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
......@@ -529,20 +531,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
)
return offsets_2d
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
q, k, v, config.qkv_layout
)
assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
kv_batch = q_batch = batch[0]
# Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
......@@ -610,28 +603,27 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
# q_spec = (...batch, q_seqlen, head, hidden)
if config.qkv_layout.is_qkvpacked():
# q_spec = (...batch, q_seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
)
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
elif config.qkv_layout.is_kvpacked():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
elif config.qkv_layout.is_separate():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
case _:
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
......@@ -705,7 +697,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -725,9 +717,9 @@ class FusedAttnBwdPrimitive(BasePrimitive):
head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
deterministic,
......@@ -787,7 +779,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
......@@ -824,9 +816,9 @@ class FusedAttnBwdPrimitive(BasePrimitive):
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
......@@ -922,7 +914,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
)
)
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
......@@ -941,20 +933,11 @@ class FusedAttnBwdPrimitive(BasePrimitive):
)
return offsets_2d
match config.qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
q, k, v, config.qkv_layout
)
assert len(batch) == 1
kv_batch = q_batch = batch[0]
# Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0:
......@@ -1088,7 +1071,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config=config,
)
global_dbias = local_dbias
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
......@@ -1098,7 +1081,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive)
def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
"""Reorders a tensor for load balancing the compute of causal attention."""
if cp_size == 1:
return tensor
......@@ -1108,7 +1091,7 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu
# Need to ensure we have 2 pairs to swap for balancing between cp ranks
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}")
raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
# [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
......@@ -1150,6 +1133,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu
return combined.reshape(ori_tensor_shape)
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
"""Reorders a tensor for load balancing with striped pattern"""
origin_shape = tensor.shape
if origin_shape[seq_dim] % cp_size != 0:
raise ValueError(
"Expected origin_shape[seq_dim] is multiple of cp_size but got"
f" {origin_shape[seq_dim]=} and {cp_size=}"
)
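# Forward: view the sequence as (S // cp_size, cp_size) and transpose it, so token t lands
# on CP rank t % cp_size. Inverse: view it as (cp_size, S // cp_size) and transpose back.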
if not is_inverse:
new_shape = [
*origin_shape[:seq_dim],
*[origin_shape[seq_dim] // cp_size, cp_size],
*origin_shape[seq_dim + 1 :],
]
else:
new_shape = [
*origin_shape[:seq_dim],
*[cp_size, origin_shape[seq_dim] // cp_size],
*origin_shape[seq_dim + 1 :],
]
chunked_tensor = tensor.reshape(new_shape)
reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
return reordered_chunked_tensor.reshape(origin_shape)
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
"""Helper class to assist with running the all-gather strategy for CP attention."""
......@@ -1161,17 +1171,17 @@ class _FusedAttnCPWithAllGatherHelper:
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
......@@ -1189,8 +1199,8 @@ class _FusedAttnCPWithAllGatherHelper:
def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type
def get_step_config(self) -> _FusedAttnConfig:
......@@ -1217,13 +1227,12 @@ class _FusedAttnCPWithAllGatherHelper:
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
return ag(k), v
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
return ag(k), ag(v)
return k, v # fall through
......@@ -1234,7 +1243,7 @@ class _FusedAttnCPWithAllGatherHelper:
def rs(x):
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
return lax_paral_op(
x,
......@@ -1245,10 +1254,9 @@ class _FusedAttnCPWithAllGatherHelper:
tiled=True,
)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
return rs(dk), dv
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
return rs(dk), rs(dv)
return dk, dv # fall through
......@@ -1286,10 +1294,9 @@ class _FusedAttnCPWithAllGatherHelper:
def sliced(x):
return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
return sliced(k), v
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
return sliced(k), sliced(v)
return k, v # fall through
......@@ -1300,11 +1307,10 @@ class _FusedAttnCPWithAllGatherHelper:
def pad(x, npad):
return jnp.pad(x, npad, "constant", constant_values=0.0)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
return pad(dk, npad), dv
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
return pad(dk, npad), pad(dv, npad)
......@@ -1378,7 +1384,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
results = []
for sub_idx in range(2):
if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK:
if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked
else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
......@@ -1514,7 +1520,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
results = []
for sub_idx in range(2):
if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK:
if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked
else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
......@@ -1544,7 +1550,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
)
# pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK:
if config.attn_mask_type != AttnMaskType.NO_MASK:
pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)
......@@ -1614,24 +1620,31 @@ class _FusedAttnCPWithP2PHelper:
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused ring attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
if self.config.qkv_layout.is_thd():
allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
else:
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
if self.config.qkv_layout.is_thd():
allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
else:
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
if self.config.max_segments_per_seq != 1:
if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
......@@ -1655,7 +1668,7 @@ class _FusedAttnCPWithP2PHelper:
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
......@@ -1668,20 +1681,18 @@ class _FusedAttnCPWithP2PHelper:
def stack_kv(self, k, v):
"""Stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=k.dtype)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
return k
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
return jnp.stack([k, v], axis=2)
return _not_used
def unstack_kv(self, kv):
"""Un-stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=kv.dtype)
match self.config.qkv_layout:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
if self.config.qkv_layout.is_kvpacked():
return kv, _not_used
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
if self.config.qkv_layout.is_separate():
return jnp.unstack(kv, axis=2)
return _not_used, _not_used # fall through
......@@ -1803,8 +1814,8 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
)
return output_per_step, softmax_aux_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
......@@ -1824,7 +1835,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
return output_per_step, softmax_aux_per_step
......@@ -1846,7 +1857,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
softmax_aux_per_step = jnp.concat(
......@@ -1865,7 +1876,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
)
return output_per_step, softmax_aux_per_step
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
......@@ -2019,8 +2030,8 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
)
return dq_per_step, dk_dv_per_step, dbias_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
......@@ -2043,7 +2054,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
dk_dv_per_step = jnp.concat(
[dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
......@@ -2081,7 +2092,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
return dq_per_step, dk_dv_per_step, dbias_per_step
......@@ -2089,7 +2100,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def skip_compute():
return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
......@@ -2107,7 +2118,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_next, dk_dv = jnp.unstack(kv_dk_dv)
dq = dq + dq_per_step
dk_dv = dk_dv + dk_dv_per_step
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
dbias = dbias + dbias_per_step
return (kv_next, dq, dk_dv, dbias)
......@@ -2124,7 +2135,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
dk_dv = helper.permute_kv(dk_dv, cp_perm)
global_dbias = dbias
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv)
......@@ -2136,6 +2147,271 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedRingAttnBwdPrimitive)
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Striped Ring Attention Forward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def fwd_impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
):
if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
raise ValueError("THD + ring attn only supports passing seqment_ids/pos")
_not_used = jnp.zeros(0, dtype=v.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
if not config.qkv_layout.is_qkvpacked():
subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
else:
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
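# (source, destination) pairs forming a ring: at every step each rank hands its current
# KV block to rank (i + 1) % cp_size, so after cp_size steps every rank has seen all KV.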
batch, q_max_seqlen, head, _ = q.shape
output = jnp.zeros(q.shape).astype(jnp.float32)
softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)
# The RNG state shape should be the per-device (sharded) shape. It is unused for ring
# attention, as we do not currently support dropout there.
rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
def scan_kv_block(idx, carry):
kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry
# TODO(rewang): check whether the last idx needs special handling
# Send KV block to next step so we can overlap compute.
kv_next = helper.permute_kv(kv, cp_perm)
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q,
kv,
_not_used,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
subblock_config,
)
# TODO(rewang): THD softmax_aux layout is actually [B, S, H]
softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
# No correction done here but we cast outputs to float32 and perform reduction
# in full precision.
return output_per_step.astype(jnp.float32), softmax_aux_per_step
def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
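# With a = softmax_aux (running log-sum-exp) and b = softmax_aux_per_step:
#   sigmoid(b - a) = e^b / (e^a + e^b), so new_out is the numerically stable merge
#   (e^a * output + e^b * output_per_step) / (e^a + e^b), and
#   a - log_sigmoid(a - b) = log(e^a + e^b) is the updated log-sum-exp.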
new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
output - output_per_step
)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
return new_out, new_aux
# On the first step there is no correction; we take the initial output and stats as-is.
output, softmax_aux = lax.cond(
idx == 0,
skip_correction,
correction,
output,
softmax_aux,
output_per_step,
softmax_aux_per_step,
)
return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)
carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(_, _, _, output, softmax_aux) = carry
softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
return output.astype(q.dtype), softmax_aux, rng_state
return mesh, fwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnStripedFwdPrimitive)
class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Striped Ring Attention Backward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
arg_shardings = tuple(arg.sharding for arg in arg_infos)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
def bwd_impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
):
if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
raise ValueError("THD + ring attn only supports passing seqment_ids/pos")
_not_used = jnp.zeros(0, dtype=output.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
if not config.qkv_layout.is_qkvpacked():
subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
else:
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
dq = jnp.zeros_like(q)
dkv = jnp.zeros_like(kv)
dbias = jnp.zeros_like(bias)
def scan_kv_block(_idx, carry):
kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry
# Start communication that feeds the next iteration.
# We further combine the tensors to improve overlap.
kv_dkv = jnp.stack([kv, dkv])
kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute():
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
config=subblock_config,
)
return dq_per_step, dkv_per_step, dbias_per_step
dq_per_step, dkv_per_step, dbias_per_step = compute()
kv_next, dkv = jnp.unstack(kv_dkv)
dq += dq_per_step
dkv += dkv_per_step
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
dbias = dbias + dbias_per_step
return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)
carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for idx in range(cp_size):
carry = scan_kv_block(idx, carry)
(_, _, _, dq, dkv, dbias) = carry
# Final permute to put gradients back to their final resting place.
dkv = helper.permute_kv(dkv, cp_perm)
global_dbias = dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dkv)
return dq, dk, dv, global_dbias
return mesh, bwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnStripedBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis:
gmr = global_mesh_resource()
......@@ -2151,9 +2427,9 @@ def fused_attn_fwd(
bias: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -2184,9 +2460,9 @@ def fused_attn_fwd(
kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch + 1,].
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......@@ -2205,22 +2481,23 @@ def fused_attn_fwd(
# For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
if qkv_layout.is_qkvpacked():
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
elif qkv_layout.is_kvpacked():
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
elif qkv_layout.is_separate():
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
......@@ -2242,6 +2519,10 @@ def fused_attn_fwd(
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING:
# We must use striped attention for THD + RING
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
......@@ -2262,9 +2543,9 @@ def fused_attn_bwd(
output: jnp.ndarray,
doutput: jnp.ndarray,
sequence_descriptor: SequenceDescriptor,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
......@@ -2296,9 +2577,9 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......@@ -2319,22 +2600,23 @@ def fused_attn_bwd(
# For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
if qkv_layout.is_qkvpacked():
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
elif qkv_layout.is_kvpacked():
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
elif qkv_layout.is_separate():
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
......@@ -2356,10 +2638,12 @@ def fused_attn_bwd(
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING:
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnBwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
*qkv_grads, bias_grad = primitive.bind(
*qkv_for_primitive,
bias,
......
......@@ -229,6 +229,10 @@ static void FusedAttnForwardImpl(
if (is_ragged) {
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
// Memset to 0xF0 to fill with large negative numbers
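// (each float then holds the bit pattern 0xF0F0F0F0 ~= -6.0e29, an effective -inf
// placeholder for softmax stats that the ragged kernel may not write)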
auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads;
cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream);
}
/* Output tensors */
......