Unverified Commit c5d6a069 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] THD ring attention (#1454)



* Support THD + ring attention for self attn
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Consolidate reorder strategy
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix dataclass frozen issue
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove redundant code
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refine P2P helper check_supported
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add segment_ids/pos check
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fixup
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add dual chunk swap example
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Align different reorder code structure
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 4b523d29
...@@ -23,6 +23,7 @@ from transformer_engine.jax.attention import ( ...@@ -23,6 +23,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing, reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing, inverse_reorder_causal_load_balancing,
CPStrategy, CPStrategy,
ReorderStrategy,
) )
...@@ -210,29 +211,29 @@ class TestDistributedCrossAttn: ...@@ -210,29 +211,29 @@ class TestDistributedCrossAttn:
"data_shape", "data_shape",
[ [
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes. # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"), pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
], ],
) )
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) @pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attn_mask_type", "qkv_layout, attn_mask_type",
[ [
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
], pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
) pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
@pytest.mark.parametrize("dtype", [jnp.bfloat16]) pytest.param(
@pytest.mark.parametrize( QKVLayout.THD_THD_THD,
"qkv_layout", AttnMaskType.PADDING_CAUSAL_MASK,
[ id="THD_SEPARATE-PADDING_CAUSAL",
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), ),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"load_balanced", "load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
) )
class TestDistributedContextParallelSelfAttn: class TestDistributedContextParallelSelfAttn:
...@@ -265,7 +266,6 @@ class TestDistributedContextParallelSelfAttn: ...@@ -265,7 +266,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
runner = FusedAttnRunner( runner = FusedAttnRunner(
batch, batch,
...@@ -282,7 +282,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -282,7 +282,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
bias_shape, bias_shape,
None, None,
SeqDescFormat.Seqlens, SeqDescFormat.SegmentIDs,
number_of_devices=device_count, number_of_devices=device_count,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
mesh_axes=mesh_axes, mesh_axes=mesh_axes,
...@@ -297,7 +297,7 @@ class TestDistributedContextParallelSelfAttn: ...@@ -297,7 +297,7 @@ class TestDistributedContextParallelSelfAttn:
dtype, dtype,
qkv_layout, qkv_layout,
attn_bias_type, attn_bias_type,
attn_mask_type, mask_type,
dropout_prob, dropout_prob,
num_head, num_head,
num_kv_heads, num_kv_heads,
...@@ -340,6 +340,8 @@ class TestDistributedContextParallelSelfAttn: ...@@ -340,6 +340,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout, qkv_layout,
load_balanced, load_balanced,
): ):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn( return self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
...@@ -377,7 +379,10 @@ class TestDistributedContextParallelSelfAttn: ...@@ -377,7 +379,10 @@ class TestDistributedContextParallelSelfAttn:
else: else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
self.impl_test_context_parallel_attn( if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
return self.impl_test_context_parallel_attn(
device_count, device_count,
mesh_shape, mesh_shape,
mesh_axes, mesh_axes,
...@@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing: ...@@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing:
], ],
) )
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format): @pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD: if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1) tensor = tensor.swapaxes(0, 1)
seq_dim = 0
ref = tensor.copy() ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reordered = reorder(tensor, cp_size, qkv_format) reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, cp_size, qkv_format) inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
assert jnp.array_equal(inversed, ref) assert jnp.array_equal(inversed, ref)
...@@ -28,12 +28,14 @@ from transformer_engine.jax.attention import ( ...@@ -28,12 +28,14 @@ from transformer_engine.jax.attention import (
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
QKVLayout, QKVLayout,
QKVFormat,
reorder_causal_load_balancing, reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing, inverse_reorder_causal_load_balancing,
fused_attn, fused_attn,
make_swa_mask, make_swa_mask,
SequenceDescriptor, SequenceDescriptor,
CPStrategy, CPStrategy,
ReorderStrategy,
) )
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import ( from transformer_engine.transformer_engine_jax import (
...@@ -347,9 +349,9 @@ class FusedAttnRunner: ...@@ -347,9 +349,9 @@ class FusedAttnRunner:
self.backend = FusedAttnHelper( self.backend = FusedAttnHelper(
self.dtype, self.dtype,
self.dtype, self.dtype,
self.qkv_layout.value, self.qkv_layout,
self.attn_bias_type.value, self.attn_bias_type,
self.attn_mask_type.value, self.attn_mask_type,
self.dropout_prob, self.dropout_prob,
self.num_heads_q, self.num_heads_q,
self.num_heads_kv, self.num_heads_kv,
...@@ -500,7 +502,8 @@ class FusedAttnRunner: ...@@ -500,7 +502,8 @@ class FusedAttnRunner:
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
) )
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD: # TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q self.pad_kv = self.pad_q
...@@ -536,6 +539,30 @@ class FusedAttnRunner: ...@@ -536,6 +539,30 @@ class FusedAttnRunner:
self.window_size, self.window_size,
) )
if self.cp_size > 1 and self.cp_load_balanced:
if self.qkv_layout.is_thd():
reorder_strategy = ReorderStrategy.Striped
else:
reorder_strategy = ReorderStrategy.DualChunkSwap
seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
# Test different input formats # Test different input formats
if self.qkv_layout.is_thd(): if self.qkv_layout.is_thd():
match self.seq_desc_format: match self.seq_desc_format:
...@@ -548,8 +575,14 @@ class FusedAttnRunner: ...@@ -548,8 +575,14 @@ class FusedAttnRunner:
) )
case SeqDescFormat.SegmentIDs: case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv), (
(self.segment_pos_q, self.segment_pos_kv), self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
) )
case _: case _:
raise ValueError(f"Unknown {self.seq_desc_format=}") raise ValueError(f"Unknown {self.seq_desc_format=}")
...@@ -605,7 +638,12 @@ class FusedAttnRunner: ...@@ -605,7 +638,12 @@ class FusedAttnRunner:
case _: case _:
def to_dp_shardings(x): def to_dp_shardings(x):
pspec = PartitionSpec(self.mesh_resource.dp_resource) if x.ndim == 1:
pspec = PartitionSpec(self.mesh_resource.dp_resource)
else:
pspec = PartitionSpec(
self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
)
return NamedSharding(self.mesh, pspec) return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
...@@ -637,24 +675,6 @@ class FusedAttnRunner: ...@@ -637,24 +675,6 @@ class FusedAttnRunner:
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
# Softmax aux sharding
if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
def test_forward(self): def test_forward(self):
""" """
Test forward without JIT Test forward without JIT
...@@ -733,15 +753,24 @@ class FusedAttnRunner: ...@@ -733,15 +753,24 @@ class FusedAttnRunner:
self._setup_inputs() self._setup_inputs()
def grad_func(func, *args, **kwargs): def grad_func(func, *args, cp_reverse_out=False, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient # Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.max_seqlen_q * self.num_heads_q gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if self.attn_mask_type.is_causal(): if self.attn_mask_type.is_causal():
gradient_multiplier /= 10 gradient_multiplier /= 10
# Keep only valid result for the gradient # Keep only valid result for the gradient
ret_valid = jnp.where( if not cp_reverse_out:
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs) ret_valid = jnp.where(
) self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
func(*args, **kwargs),
)
else:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
self.cp_inverse_reorder_fn(func(*args, **kwargs)),
)
return ( return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype) ).astype(self.dtype)
...@@ -787,7 +816,7 @@ class FusedAttnRunner: ...@@ -787,7 +816,7 @@ class FusedAttnRunner:
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func( lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, **kwargs customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
), ),
arg_nums, arg_nums,
), ),
......
...@@ -135,6 +135,39 @@ class QKVLayout(Enum): ...@@ -135,6 +135,39 @@ class QKVLayout(Enum):
""" """
return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
def to_qkvpacked(self):
"""
Return the corresponding qkvpacked format, useful when adjusting q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BS3HD
if qkv_format == QKVFormat.THD:
return QKVLayout.T3HD
raise ValueError(f"Unsupported {qkv_format=}")
def to_kvpacked(self):
"""
Return the corresponding kvpacked format, useful when adjusting q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BSHD_BS2HD
if qkv_format == QKVFormat.THD:
return QKVLayout.THD_T2HD
raise ValueError(f"Unsupported {qkv_format=}")
def to_separate(self):
"""
Return the corresponding separate format, useful when adjusting q, k, v layout
"""
qkv_format = self.get_qkv_format()
if qkv_format == QKVFormat.BSHD:
return QKVLayout.BSHD_BSHD_BSHD
if qkv_format == QKVFormat.THD:
return QKVLayout.THD_THD_THD
raise ValueError(f"Unsupported {qkv_format=}")
class CPStrategy(Enum): class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention. """Defines the context parallel strategies of Jax fused attention.
...@@ -149,6 +182,28 @@ class CPStrategy(Enum): ...@@ -149,6 +182,28 @@ class CPStrategy(Enum):
RING = 2 RING = 2
class ReorderStrategy(Enum):
"""
Defines the tokens re-order strategy for context parallel load balancing for causal mask.
- DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between
GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
mulitple of 2 * cp_size.
Examples:
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
- After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
- Striped: This strategy distributes the tokens in a striped (interleaved) manner across
the sequence. This is currently used for THD load balance.
Example: Consider 4 GPUs with seqlens=16.
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
- After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
"""
DualChunkSwap = 0
Striped = 1
def make_swa_mask( def make_swa_mask(
segment_pos_q: jnp.ndarray, segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray, segment_pos_kv: jnp.ndarray,
...@@ -243,9 +298,9 @@ def is_fused_attn_kernel_available( ...@@ -243,9 +298,9 @@ def is_fused_attn_kernel_available(
return tex.FusedAttnHelper( return tex.FusedAttnHelper(
q_dtype, q_dtype,
kv_dtype, kv_dtype,
qkv_layout.value, qkv_layout,
attn_bias_type.value, attn_bias_type,
attn_mask_type.value, attn_mask_type,
dropout_probability, dropout_probability,
q_num_heads, q_num_heads,
kv_num_heads, kv_num_heads,
...@@ -276,16 +331,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout): ...@@ -276,16 +331,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen return batch, q_max_seqlen, kv_max_seqlen
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
"""Reorders a tensor for load balancing the compute of causal attention.""" """Reorders a tensor for load balancing the compute of causal attention."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False) return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
raise ValueError(f"Unsupported {strategy=}")
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
):
"""Inverse operation of `reorder_causal_load_balancing`.""" """Inverse operation of `reorder_causal_load_balancing`."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
raise ValueError(f"Unsupported {strategy=}")
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq): def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
...@@ -412,8 +475,6 @@ class SequenceDescriptor: ...@@ -412,8 +475,6 @@ class SequenceDescriptor:
""" """
Acquire the seqlens/offsets for cuDNN backend Acquire the seqlens/offsets for cuDNN backend
""" """
attn_mask_type = AttnMaskType(attn_mask_type)
qkv_layout = QKVLayout(qkv_layout)
q_segment_ids, kv_segment_ids = self.segment_ids q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos q_segment_pos, kv_segment_pos = self.segment_pos
assert q_segment_ids.shape == q_segment_pos.shape assert q_segment_ids.shape == q_segment_pos.shape
...@@ -589,9 +650,9 @@ def _legacy_fused_attn( ...@@ -589,9 +650,9 @@ def _legacy_fused_attn(
Intra-sequence padding is not valid. The padded tokens can only on the right-most. Intra-sequence padding is not valid. The padded tokens can only on the right-most.
Otherwise the results will be wrong. Otherwise the results will be wrong.
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
...@@ -608,16 +669,18 @@ def _legacy_fused_attn( ...@@ -608,16 +669,18 @@ def _legacy_fused_attn(
# Check inputs qkv # Check inputs qkv
match qkv_layout: match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD: case QKVLayout.BS3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: case QKVLayout.BSHD_BS2HD:
assert ( assert (
len(qkv) == 2 len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: case QKVLayout.BSHD_BSHD_BSHD:
assert ( assert (
len(qkv) == 3 len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
# convert the mask to seqlens, mask doesn't support ragged offsets # convert the mask to seqlens, mask doesn't support ragged offsets
if not attn_mask_type.is_padding(): if not attn_mask_type.is_padding():
...@@ -689,16 +752,18 @@ def fused_attn_thd( ...@@ -689,16 +752,18 @@ def fused_attn_thd(
# Check inputs qkv # Check inputs qkv
match qkv_layout: match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD: case QKVLayout.T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_T2HD: case QKVLayout.THD_T2HD:
assert ( assert (
len(qkv) == 2 len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_THD_THD: case QKVLayout.THD_THD_THD:
assert ( assert (
len(qkv) == 3 len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
assert q_seq_lens.shape == (batch, q_max_seqlen) assert q_seq_lens.shape == (batch, q_max_seqlen)
...@@ -789,9 +854,9 @@ def _fused_attn_fwd_rule( ...@@ -789,9 +854,9 @@ def _fused_attn_fwd_rule(
bias, bias,
sequence_descriptor, sequence_descriptor,
seed, seed,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout.value, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
...@@ -845,9 +910,9 @@ def _fused_attn_bwd_rule( ...@@ -845,9 +910,9 @@ def _fused_attn_bwd_rule(
output, output,
dz, dz,
sequence_descriptor, sequence_descriptor,
attn_bias_type=attn_bias_type.value, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type.value, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout.value, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
...@@ -903,9 +968,9 @@ def fused_attn( ...@@ -903,9 +968,9 @@ def fused_attn(
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence. sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence.
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for attention""" """JAX/TE custom ops for attention"""
from dataclasses import dataclass from dataclasses import dataclass, replace
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import os import os
...@@ -17,17 +17,18 @@ from jax.interpreters.mlir import ir ...@@ -17,17 +17,18 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi from jax import ffi
from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
CPStrategy,
SequenceDescriptor,
)
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import ( from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_QKV_Layout,
NVTE_QKV_Format,
NVTE_Fused_Attn_Backend,
nvte_get_qkv_format,
)
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import ( from .misc import (
...@@ -79,9 +80,9 @@ class _FusedAttnConfig: ...@@ -79,9 +80,9 @@ class _FusedAttnConfig:
Passes static configuration properties of fused attention. Passes static configuration properties of fused attention.
""" """
attn_bias_type: NVTE_Bias_Type attn_bias_type: AttnBiasType
attn_mask_type: NVTE_Mask_Type attn_mask_type: AttnMaskType
qkv_layout: NVTE_QKV_Layout qkv_layout: QKVLayout
scaling_factor: float scaling_factor: float
dropout_probability: float dropout_probability: float
is_training: bool is_training: bool
...@@ -99,9 +100,9 @@ class FusedAttnHelper: ...@@ -99,9 +100,9 @@ class FusedAttnHelper:
q_dtype: jnp.dtype q_dtype: jnp.dtype
kv_dtype: jnp.dtype kv_dtype: jnp.dtype
qkv_layout: NVTE_QKV_Layout qkv_layout: QKVLayout
attn_bias_type: NVTE_Bias_Type attn_bias_type: AttnBiasType
attn_mask_type: NVTE_Mask_Type attn_mask_type: AttnMaskType
dropout_probability: float dropout_probability: float
q_num_heads: int q_num_heads: int
kv_num_heads: int kv_num_heads: int
...@@ -119,9 +120,9 @@ class FusedAttnHelper: ...@@ -119,9 +120,9 @@ class FusedAttnHelper:
return transformer_engine_jax.get_fused_attn_backend( return transformer_engine_jax.get_fused_attn_backend(
jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype), jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout, self.qkv_layout.value,
self.attn_bias_type, self.attn_bias_type.value,
self.attn_mask_type, self.attn_mask_type.value,
self.dropout_probability, self.dropout_probability,
self.q_num_heads, self.q_num_heads,
self.kv_num_heads, self.kv_num_heads,
...@@ -140,24 +141,25 @@ class FusedAttnHelper: ...@@ -140,24 +141,25 @@ class FusedAttnHelper:
@staticmethod @staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval""" """Parse qkv aval"""
match qkv_layout: if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: raise NotImplementedError
*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape if qkv_layout.is_qkvpacked():
kv_batch_shape = q_batch_shape *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
kv_max_seqlen = q_max_seqlen kv_batch_shape = q_batch_shape
num_gqa_groups = attn_heads kv_max_seqlen = q_max_seqlen
kv_head_dim = q_head_dim num_gqa_groups = attn_heads
assert nqkv == 3 kv_head_dim = q_head_dim
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: assert nqkv == 3
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape elif qkv_layout.is_kvpacked():
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
assert nkv == 2 *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: assert nkv == 2
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape elif qkv_layout.is_separate():
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
assert k_aval.shape == v_aval.shape *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
case _: assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
raise ValueError(f"Unexpected {qkv_layout=}") else:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_batch_shape == kv_batch_shape assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim assert q_head_dim == kv_head_dim
assert q_aval.dtype == k_aval.dtype == v_aval.dtype assert q_aval.dtype == k_aval.dtype == v_aval.dtype
...@@ -310,7 +312,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -310,7 +312,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0 bias_batch = bias_heads = 0
else: else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
...@@ -330,9 +332,9 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -330,9 +332,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
head_dim, head_dim,
config.scaling_factor, config.scaling_factor,
config.dropout_probability, config.dropout_probability,
config.attn_bias_type, config.attn_bias_type.value,
config.attn_mask_type, config.attn_mask_type.value,
config.qkv_layout, config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training, config.is_training,
config.max_segments_per_seq, config.max_segments_per_seq,
...@@ -385,7 +387,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -385,7 +387,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
input_batch = reduce(operator.mul, batch_shape) input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0 bias_batch = bias_heads = 0
else: else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
...@@ -419,9 +421,9 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -419,9 +421,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
max_segments_per_seq=config.max_segments_per_seq, max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor), scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability), dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type), bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type), mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout), qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training, is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0], window_size_left=config.window_size[0],
...@@ -511,7 +513,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -511,7 +513,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
) )
) )
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1): def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape x_shape = x.shape
...@@ -529,20 +531,11 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -529,20 +531,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
) )
return offsets_2d return offsets_2d
match config.qkv_layout: batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
case NVTE_QKV_Layout.NVTE_T3HD: q, k, v, config.qkv_layout
kv_max_seqlen = q_max_seqlen = q.shape[-4] )
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
case NVTE_QKV_Layout.NVTE_THD_T2HD: kv_batch = q_batch = batch[0]
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0 # Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0: # cuDNN version < 9.3.0:
...@@ -610,29 +603,28 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -610,29 +603,28 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos del result_infos
q_spec = get_padded_spec(arg_infos[0]) q_spec = get_padded_spec(arg_infos[0])
match config.qkv_layout: if config.qkv_layout.is_qkvpacked():
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: # q_spec = (...batch, q_seqlen, 3, head, hidden)
# q_spec = (...batch, q_seqlen, head, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) softmax_aux_sharding = NamedSharding(
softmax_aux_sharding = NamedSharding( mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) )
) elif config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: # q_spec = (...batch, q_seqlen, head, hidden)
# q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding(
softmax_aux_sharding = NamedSharding( mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) )
) elif config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: # q_spec = (...batch, q_seqlen, head, hidden)
# q_spec = (...batch, q_seqlen, head, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) softmax_aux_sharding = NamedSharding(
softmax_aux_sharding = NamedSharding( mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) )
) else:
case _: raise ValueError(f"Unsupported {config.qkv_layout=}")
raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding) return (out_sharding, softmax_aux_sharding, rng_state_sharding)
...@@ -705,7 +697,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -705,7 +697,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
) )
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0 bias_batch = bias_heads = 0
else: else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
...@@ -725,9 +717,9 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -725,9 +717,9 @@ class FusedAttnBwdPrimitive(BasePrimitive):
head_dim, head_dim,
config.scaling_factor, config.scaling_factor,
config.dropout_probability, config.dropout_probability,
config.attn_bias_type, config.attn_bias_type.value,
config.attn_mask_type, config.attn_mask_type.value,
config.qkv_layout, config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training, config.is_training,
deterministic, deterministic,
...@@ -787,7 +779,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -787,7 +779,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
input_batch = reduce(operator.mul, batch_shape) input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0 bias_batch = bias_heads = 0
else: else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
...@@ -824,9 +816,9 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -824,9 +816,9 @@ class FusedAttnBwdPrimitive(BasePrimitive):
max_segments_per_seq=config.max_segments_per_seq, max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor), scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability), dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type), bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type), mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout), qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training, is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0], window_size_left=config.window_size[0],
...@@ -922,7 +914,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -922,7 +914,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
) )
) )
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1): def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape x_shape = x.shape
...@@ -941,20 +933,11 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -941,20 +933,11 @@ class FusedAttnBwdPrimitive(BasePrimitive):
) )
return offsets_2d return offsets_2d
match config.qkv_layout: batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
case NVTE_QKV_Layout.NVTE_T3HD: q, k, v, config.qkv_layout
kv_max_seqlen = q_max_seqlen = q.shape[-4] )
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) assert len(batch) == 1
case NVTE_QKV_Layout.NVTE_THD_T2HD: kv_batch = q_batch = batch[0]
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather valid q_seqlen, which is greater than 0 # Gather valid q_seqlen, which is greater than 0
# cuDNN version < 9.3.0: # cuDNN version < 9.3.0:
...@@ -1088,7 +1071,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1088,7 +1071,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config=config, config=config,
) )
global_dbias = local_dbias global_dbias = local_dbias
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias return local_dq, local_dk, local_dv, global_dbias
...@@ -1098,7 +1081,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1098,7 +1081,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
"""Reorders a tensor for load balancing the compute of causal attention.""" """Reorders a tensor for load balancing the compute of causal attention."""
if cp_size == 1: if cp_size == 1:
return tensor return tensor
...@@ -1108,7 +1091,7 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu ...@@ -1108,7 +1091,7 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu
# Need to ensure we have 2 pairs to swap for balancing between cp ranks # Need to ensure we have 2 pairs to swap for balancing between cp ranks
if tensor.shape[seq_dim] % (cp_size * 2) != 0: if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D]
# [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D]
...@@ -1150,6 +1133,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu ...@@ -1150,6 +1133,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu
return combined.reshape(ori_tensor_shape) return combined.reshape(ori_tensor_shape)
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
    """Apply (or undo) the striped load-balancing permutation along ``seq_dim``.

    Forward (``is_inverse=False``): tokens are dealt round-robin across
    ``cp_size`` stripes, i.e. stripe ``r`` receives tokens ``r, r+cp_size, ...``.
    Inverse (``is_inverse=True``): restores the original token order.

    Raises:
        ValueError: if the sequence length is not a multiple of ``cp_size``.
    """
    origin_shape = tensor.shape
    if origin_shape[seq_dim] % cp_size != 0:
        raise ValueError(
            "Expected origin_shape[seq_dim] is multiple of cp_size but got"
            f" {origin_shape[seq_dim]=} and {cp_size=}"
        )
    stripe_len = origin_shape[seq_dim] // cp_size
    # Forward splits seq into (stripe_len, cp_size); inverse splits into
    # (cp_size, stripe_len). Swapping the two new axes then performs the
    # round-robin (or its inverse) in one shot.
    split_pair = (cp_size, stripe_len) if is_inverse else (stripe_len, cp_size)
    chunked = tensor.reshape(
        (*origin_shape[:seq_dim], *split_pair, *origin_shape[seq_dim + 1 :])
    )
    return jnp.swapaxes(chunked, seq_dim, seq_dim + 1).reshape(origin_shape)
@dataclass(frozen=True) @dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper: class _FusedAttnCPWithAllGatherHelper:
"""Helper class to assist with running the all-gather strategy for CP attention.""" """Helper class to assist with running the all-gather strategy for CP attention."""
...@@ -1161,17 +1171,17 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1161,17 +1171,17 @@ class _FusedAttnCPWithAllGatherHelper:
"""Checks if the context parallel implementation is supported by the given arguments.""" """Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention" header = "Context parallel fused attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts: if self.config.qkv_layout not in allowed_layouts:
raise ValueError( raise ValueError(
f"{header} only supports layouts:" f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
) )
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks: if self.config.attn_mask_type not in allowed_masks:
raise ValueError( raise ValueError(
f"{header} only supports masking types: " f"{header} only supports masking types: "
...@@ -1189,8 +1199,8 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1189,8 +1199,8 @@ class _FusedAttnCPWithAllGatherHelper:
def get_adjusted_mask(self): def get_adjusted_mask(self):
"""Converts the mask for context parallelism.""" """Converts the mask for context parallelism."""
if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type return self.config.attn_mask_type
def get_step_config(self) -> _FusedAttnConfig: def get_step_config(self) -> _FusedAttnConfig:
...@@ -1217,14 +1227,13 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1217,14 +1227,13 @@ class _FusedAttnCPWithAllGatherHelper:
) )
if self.config.context_parallel_load_balanced: if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True) x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x return x
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: return ag(k), v
return ag(k), v if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: return ag(k), ag(v)
return ag(k), ag(v)
return k, v # fall through return k, v # fall through
...@@ -1234,7 +1243,7 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1234,7 +1243,7 @@ class _FusedAttnCPWithAllGatherHelper:
def rs(x): def rs(x):
if self.config.context_parallel_load_balanced: if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False) x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
return lax_paral_op( return lax_paral_op(
x, x,
...@@ -1245,11 +1254,10 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1245,11 +1254,10 @@ class _FusedAttnCPWithAllGatherHelper:
tiled=True, tiled=True,
) )
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: return rs(dk), dv
return rs(dk), dv if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: return rs(dk), rs(dv)
return rs(dk), rs(dv)
return dk, dv # fall through return dk, dv # fall through
...@@ -1286,11 +1294,10 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1286,11 +1294,10 @@ class _FusedAttnCPWithAllGatherHelper:
def sliced(x): def sliced(x):
return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1) return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: return sliced(k), v
return sliced(k), v if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: return sliced(k), sliced(v)
return sliced(k), sliced(v)
return k, v # fall through return k, v # fall through
...@@ -1300,13 +1307,12 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1300,13 +1307,12 @@ class _FusedAttnCPWithAllGatherHelper:
def pad(x, npad): def pad(x, npad):
return jnp.pad(x, npad, "constant", constant_values=0.0) return jnp.pad(x, npad, "constant", constant_values=0.0)
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] return pad(dk, npad), dv
return pad(dk, npad), dv if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] return pad(dk, npad), pad(dv, npad)
return pad(dk, npad), pad(dv, npad)
return dk, dv # fall through return dk, dv # fall through
...@@ -1378,7 +1384,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1378,7 +1384,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
results = [] results = []
for sub_idx in range(2): for sub_idx in range(2):
if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked k_unmasked, v_unmasked = k, v # full kv used for unmasked
else: else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
...@@ -1514,7 +1520,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1514,7 +1520,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
results = [] results = []
for sub_idx in range(2): for sub_idx in range(2):
if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked k_unmasked, v_unmasked = k, v # full kv used for unmasked
else: else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
...@@ -1544,7 +1550,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -1544,7 +1550,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
) )
# pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: if config.attn_mask_type != AttnMaskType.NO_MASK:
pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)
...@@ -1614,24 +1620,31 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1614,24 +1620,31 @@ class _FusedAttnCPWithP2PHelper:
"""Checks if the context parallel implementation is supported by the given arguments.""" """Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused ring attention" header = "Context parallel fused ring attention"
allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] if self.config.qkv_layout.is_thd():
allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
else:
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts: if self.config.qkv_layout not in allowed_layouts:
raise ValueError( raise ValueError(
f"{header} only supports layouts:" f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
) )
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] if self.config.qkv_layout.is_thd():
allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
else:
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks: if self.config.attn_mask_type not in allowed_masks:
raise ValueError( raise ValueError(
f"{header} only supports masking types: " f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
) )
if self.config.max_segments_per_seq != 1: if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
raise ValueError( raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:" f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}" f" {self.config.max_segments_per_seq}"
...@@ -1655,7 +1668,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1655,7 +1668,7 @@ class _FusedAttnCPWithP2PHelper:
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability, dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training, is_training=self.config.is_training,
...@@ -1668,21 +1681,19 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1668,21 +1681,19 @@ class _FusedAttnCPWithP2PHelper:
def stack_kv(self, k, v): def stack_kv(self, k, v):
"""Stacks k and v tensors if not stacked.""" """Stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=k.dtype) _not_used = jnp.zeros(0, dtype=k.dtype)
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: return k
return k if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: return jnp.stack([k, v], axis=2)
return jnp.stack([k, v], axis=2)
return _not_used return _not_used
def unstack_kv(self, kv): def unstack_kv(self, kv):
"""Un-stacks k and v tensors if not stacked.""" """Un-stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=kv.dtype) _not_used = jnp.zeros(0, dtype=kv.dtype)
match self.config.qkv_layout: if self.config.qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: return kv, _not_used
return kv, _not_used if self.config.qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: return jnp.unstack(kv, axis=2)
return jnp.unstack(kv, axis=2)
return _not_used, _not_used # fall through return _not_used, _not_used # fall through
def permute_kv(self, kv, cp_perm): def permute_kv(self, kv, cp_perm):
...@@ -1803,8 +1814,8 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1803,8 +1814,8 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
) )
return output_per_step, softmax_aux_per_step return output_per_step, softmax_aux_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute(): def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
...@@ -1824,7 +1835,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1824,7 +1835,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
_kv_segment_ids, _kv_segment_ids,
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), config=helper.get_step_config(AttnMaskType.NO_MASK),
) )
return output_per_step, softmax_aux_per_step return output_per_step, softmax_aux_per_step
...@@ -1846,7 +1857,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1846,7 +1857,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
_kv_segment_ids, _kv_segment_ids,
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), config=helper.get_step_config(AttnMaskType.NO_MASK),
) )
output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
softmax_aux_per_step = jnp.concat( softmax_aux_per_step = jnp.concat(
...@@ -1865,7 +1876,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): ...@@ -1865,7 +1876,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
) )
return output_per_step, softmax_aux_per_step return output_per_step, softmax_aux_per_step
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond # This is for nested jax.lax.cond
def jax_cond_wrap(): def jax_cond_wrap():
if config.context_parallel_load_balanced: if config.context_parallel_load_balanced:
...@@ -2019,8 +2030,8 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2019,8 +2030,8 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
) )
return dq_per_step, dk_dv_per_step, dbias_per_step return dq_per_step, dk_dv_per_step, dbias_per_step
causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute(): def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
...@@ -2043,7 +2054,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2043,7 +2054,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
_kv_segment_ids, _kv_segment_ids,
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), config=helper.get_step_config(AttnMaskType.NO_MASK),
) )
dk_dv_per_step = jnp.concat( dk_dv_per_step = jnp.concat(
[dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1 [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
...@@ -2081,7 +2092,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2081,7 +2092,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
_kv_segment_ids, _kv_segment_ids,
_q_segment_pos, _q_segment_pos,
_kv_segment_pos, _kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), config=helper.get_step_config(AttnMaskType.NO_MASK),
) )
dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
return dq_per_step, dk_dv_per_step, dbias_per_step return dq_per_step, dk_dv_per_step, dbias_per_step
...@@ -2089,7 +2100,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2089,7 +2100,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def skip_compute(): def skip_compute():
return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias) return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)
if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond # This is for nested jax.lax.cond
def jax_cond_wrap(): def jax_cond_wrap():
if config.context_parallel_load_balanced: if config.context_parallel_load_balanced:
...@@ -2107,7 +2118,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2107,7 +2118,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_next, dk_dv = jnp.unstack(kv_dk_dv) kv_next, dk_dv = jnp.unstack(kv_dk_dv)
dq = dq + dq_per_step dq = dq + dq_per_step
dk_dv = dk_dv + dk_dv_per_step dk_dv = dk_dv + dk_dv_per_step
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type is not AttnBiasType.NO_BIAS:
dbias = dbias + dbias_per_step dbias = dbias + dbias_per_step
return (kv_next, dq, dk_dv, dbias) return (kv_next, dq, dk_dv, dbias)
...@@ -2124,7 +2135,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2124,7 +2135,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
dk_dv = helper.permute_kv(dk_dv, cp_perm) dk_dv = helper.permute_kv(dk_dv, cp_perm)
global_dbias = dbias global_dbias = dbias
if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv) dk, dv = helper.unstack_kv(dk_dv)
...@@ -2136,6 +2147,271 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): ...@@ -2136,6 +2147,271 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
register_primitive(FusedRingAttnBwdPrimitive) register_primitive(FusedRingAttnBwdPrimitive)
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
    """
    Fused Striped Ring Attention Forward Primitive

    Context-parallel (CP) ring attention forward pass: each step computes
    attention of the local Q against the currently-held KV shard, then passes
    the KV shard (with its segment ids/positions) to the next CP rank.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        # Custom-partitioning hook: returns (mesh, impl, out_shardings, arg_shardings).
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            # Single CP rank: no ring needed, fall back to the base fused attention.
            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(
            mesh, PartitionSpec(get_all_mesh_axes(), None)
        )
        arg_shardings = [arg_i.sharding for arg_i in arg_infos]
        # Argument 4 is the dropout RNG seed; force it onto the seed sharding.
        arg_shardings[4] = seed_sharding
        arg_shardings = tuple(arg_shardings)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)

        def fwd_impl(
            q,
            k,
            v,
            bias,
            seed,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            # Ring attention here requires per-token segment metadata; reject
            # empty placeholder arrays up front.
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing seqment_ids/pos")

            _not_used = jnp.zeros(0, dtype=v.dtype)

            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)
            if not config.qkv_layout.is_qkvpacked():
                # Per-step compute consumes the stacked KV, so run each step
                # with the kv-packed variant of the configured layout.
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # Ring schedule: rank i forwards its KV shard to rank (i + 1) % cp_size.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            batch, q_max_seqlen, head, _ = q.shape
            # Accumulate output and softmax stats in float32 for numerical stability;
            # cast back to q.dtype only at the end.
            output = jnp.zeros(q.shape).astype(jnp.float32)
            softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)

            # RNG shape should be the shared shape. This is unused for ring attention as we do not
            # support dropout currently.
            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

            def scan_kv_block(idx, carry):
                # carry: (kv shard, its segment ids/pos, running output, running stats)
                kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry

                # TODO(rewang): To check whether we need special handle for the last idx
                # Send KV block to next step so we can overlap compute.
                kv_next = helper.permute_kv(kv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
                    q,
                    kv,
                    _not_used,
                    bias,
                    seed,
                    q_seqlen,
                    kv_seqlen,
                    q_seq_offsets,
                    k_seq_offsets,
                    q_segment_ids,
                    kv_segment_ids,
                    q_segment_pos,
                    kv_segment_pos,
                    subblock_config,
                )
                # TODO(rewang): THD softmax_aux layout is actually [B, S, H]
                softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))

                def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
                    # No correction done here but we cast outputs to float32 and perform reduction
                    # in full precision.
                    return output_per_step.astype(jnp.float32), softmax_aux_per_step

                def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
                    # Online-softmax style merge of this step's partial result into
                    # the running accumulators via sigmoid/log-sigmoid rescaling.
                    new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
                        output - output_per_step
                    )
                    new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
                    return new_out, new_aux

                # first step there is no correction we get initial output and stats
                output, softmax_aux = lax.cond(
                    idx == 0,
                    skip_correction,
                    correction,
                    output,
                    softmax_aux,
                    output_per_step,
                    softmax_aux_per_step,
                )
                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)

            carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                # Unrolled fallback when the scan loop is disabled/unsupported.
                for i in range(0, cp_size):
                    carry = scan_kv_block(i, carry)
            (_, _, _, output, softmax_aux) = carry
            # NOTE(review): reshape (not transpose) from [B, S, H, 1] to [B, H, S, 1];
            # presumably this matches the THD aux-layout quirk noted above — confirm
            # against FusedAttnFwdPrimitive's softmax_aux contract.
            softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1))
            return output.astype(q.dtype), softmax_aux, rng_state

        return mesh, fwd_impl, out_shardings, arg_shardings
# Register the striped ring-attention forward primitive (project-local helper).
register_primitive(FusedRingAttnStripedFwdPrimitive)
class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
    """
    Fused Striped Ring Attention Backward Primitive

    Backward pass counterpart of the striped ring-attention forward: rotates
    the stacked KV block together with its gradient accumulator around the
    context-parallel ring so communication overlaps each per-step backward
    compute.
    """

    @staticmethod
    def partition(config, mesh, arg_infos, result_infos):
        """Build the sharded backward implementation for striped ring attention.

        Falls back to ``FusedAttnBwdPrimitive`` partitioning when the
        context-parallel mesh axis has size 1.

        Fix vs. previous revision: corrected the "seqment_ids" typo in the
        user-facing error message.
        """
        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
        assert (
            not is_context_parallel or config.window_size[0] == -1
        ), "Sliding window attention is not supported when context parallelism is enabled"
        if not is_context_parallel:
            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)

        arg_shardings = tuple(arg.sharding for arg in arg_infos)
        # dq, dk, dv, dbias sharding = q, k, v, bias sharding
        out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
        helper = _FusedAttnCPWithP2PHelper(mesh, config)
        helper.check_supported()

        def bwd_impl(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_seqlen,
            kv_seqlen,
            q_seq_offsets,
            k_seq_offsets,
            q_segment_ids,
            kv_segment_ids,
            q_segment_pos,
            kv_segment_pos,
        ):
            # THD ring attention requires segment ids/pos; `.size` is static so
            # this check happens at trace time, not on device.
            if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
                raise ValueError("THD + ring attn only supports passing segment_ids/pos")

            _not_used = jnp.zeros(0, dtype=output.dtype)
            # Combine KV tensors if separate for better permute scheduling and performance.
            # Eventually XLA should perform this automatically.
            kv = helper.stack_kv(k, v)

            # The per-step subblock always sees a KV-packed layout once K/V are stacked.
            if not config.qkv_layout.is_qkvpacked():
                subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
            else:
                subblock_config = config

            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
            # Ring permutation: each rank sends its KV block to the next rank.
            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

            dq = jnp.zeros_like(q)
            dkv = jnp.zeros_like(kv)
            dbias = jnp.zeros_like(bias)

            def scan_kv_block(_idx, carry):
                kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry

                # Start communication that feeds the next iteration.
                # We further combine the tensors to improve overlap.
                kv_dkv = jnp.stack([kv, dkv])
                kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
                kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
                kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)

                def compute():
                    dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
                        q,
                        kv,
                        _not_used,
                        bias,
                        softmax_aux,
                        rng_state,
                        output,
                        doutput,
                        q_seqlen,
                        kv_seqlen,
                        q_seq_offsets,
                        k_seq_offsets,
                        q_segment_ids,
                        kv_segment_ids,
                        q_segment_pos,
                        kv_segment_pos,
                        config=subblock_config,
                    )
                    return dq_per_step, dkv_per_step, dbias_per_step

                dq_per_step, dkv_per_step, dbias_per_step = compute()
                # Unstack after the permute: dkv travels around the ring with kv
                # so each step's contribution is accumulated in the right frame.
                kv_next, dkv = jnp.unstack(kv_dkv)
                dq += dq_per_step
                dkv += dkv_per_step
                if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                    dbias = dbias + dbias_per_step
                return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)

            carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
            if helper.use_scanloop():
                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
            else:
                for idx in range(cp_size):
                    carry = scan_kv_block(idx, carry)
            (_, _, _, dq, dkv, dbias) = carry

            # Final permute to put gradients back to their final resting place.
            dkv = helper.permute_kv(dkv, cp_perm)

            global_dbias = dbias
            if config.attn_bias_type is not AttnBiasType.NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)

            dk, dv = helper.unstack_kv(dkv)
            return dq, dk, dv, global_dbias

        return mesh, bwd_impl, out_shardings, arg_shardings
# Register the striped ring-attention backward primitive (project-local helper).
register_primitive(FusedRingAttnStripedBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str): def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis: if not cp_axis:
gmr = global_mesh_resource() gmr = global_mesh_resource()
...@@ -2151,9 +2427,9 @@ def fused_attn_fwd( ...@@ -2151,9 +2427,9 @@ def fused_attn_fwd(
bias: Optional[jnp.ndarray], bias: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray], seed: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type, attn_bias_type: AttnBiasType,
attn_mask_type: NVTE_Mask_Type, attn_mask_type: AttnMaskType,
qkv_layout: NVTE_QKV_Layout, qkv_layout: QKVLayout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -2184,9 +2460,9 @@ def fused_attn_fwd( ...@@ -2184,9 +2460,9 @@ def fused_attn_fwd(
kv_seq_offsets (jnp.ndarray): kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch + 1,]. The offsets in the sequence dim for the query, with shape [batch + 1,].
seed (Optional[jnp.ndarray]): Optional random seed for dropout. seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
...@@ -2205,22 +2481,23 @@ def fused_attn_fwd( ...@@ -2205,22 +2481,23 @@ def fused_attn_fwd(
# For optional tensors, which custom calls doesn't support None # For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype) _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout: if qkv_layout.is_qkvpacked():
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = [*qkv, _not_used, _not_used]
qkv_for_primitive = [*qkv, _not_used, _not_used] elif qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: assert (
assert ( len(qkv) == 2
len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = [*qkv, _not_used]
qkv_for_primitive = [*qkv, _not_used] elif qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: assert (
assert ( len(qkv) == 3
len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = qkv
qkv_for_primitive = qkv else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
...@@ -2242,7 +2519,11 @@ def fused_attn_fwd( ...@@ -2242,7 +2519,11 @@ def fused_attn_fwd(
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING: case CPStrategy.RING:
primitive = FusedRingAttnFwdPrimitive.outer_primitive # We must use stripe attention for THD-RING
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind( return primitive.bind(
...@@ -2262,9 +2543,9 @@ def fused_attn_bwd( ...@@ -2262,9 +2543,9 @@ def fused_attn_bwd(
output: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, doutput: jnp.ndarray,
sequence_descriptor: SequenceDescriptor, sequence_descriptor: SequenceDescriptor,
attn_bias_type: NVTE_Bias_Type, attn_bias_type: AttnBiasType,
attn_mask_type: NVTE_Mask_Type, attn_mask_type: AttnMaskType,
qkv_layout: NVTE_QKV_Layout, qkv_layout: QKVLayout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
...@@ -2296,9 +2577,9 @@ def fused_attn_bwd( ...@@ -2296,9 +2577,9 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,]. The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (jnp.ndarray): kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch + 1,]. The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (NVTE_Bias_Type): Type of attention bias. attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask. attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores. scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention. dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode. is_training (bool): Flag indicating whether the model is in training mode.
...@@ -2319,22 +2600,23 @@ def fused_attn_bwd( ...@@ -2319,22 +2600,23 @@ def fused_attn_bwd(
# For optional tensors, which custom calls doesn't support None # For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype) _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout: if qkv_layout.is_qkvpacked():
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = [*qkv, _not_used, _not_used]
qkv_for_primitive = [*qkv, _not_used, _not_used] elif qkv_layout.is_kvpacked():
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: assert (
assert ( len(qkv) == 2
len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = [*qkv, _not_used]
qkv_for_primitive = [*qkv, _not_used] elif qkv_layout.is_separate():
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: assert (
assert ( len(qkv) == 3
len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" qkv_for_primitive = qkv
qkv_for_primitive = qkv else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
...@@ -2356,10 +2638,12 @@ def fused_attn_bwd( ...@@ -2356,10 +2638,12 @@ def fused_attn_bwd(
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING: case CPStrategy.RING:
primitive = FusedRingAttnBwdPrimitive.outer_primitive if qkv_layout.is_thd():
primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnBwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
*qkv_grads, bias_grad = primitive.bind( *qkv_grads, bias_grad = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
......
...@@ -229,6 +229,10 @@ static void FusedAttnForwardImpl( ...@@ -229,6 +229,10 @@ static void FusedAttnForwardImpl(
if (is_ragged) { if (is_ragged) {
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
// Memset to 0xF0 for filling large negative numbers
auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads;
cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream);
} }
/* Output tensors */ /* Output tensors */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment