Unverified Commit c5d6a069 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] THD ring attention (#1454)



* Support THD + ring attention for self attn
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Consolidate reorder strategy
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix dataclass frozen issue
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove redundant code
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extension
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refine P2P helper check_supported
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add segment_ids/pos check
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fixup
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add dual chunk swap example
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Align different reorder code structure
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 4b523d29
......@@ -23,6 +23,7 @@ from transformer_engine.jax.attention import (
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
ReorderStrategy,
)
......@@ -210,29 +211,29 @@ class TestDistributedCrossAttn:
"data_shape",
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest.param([2, 128, 12, 128], id="2-128xCP-12-128"),
pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize("kv_groups", [1, 8])
@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
@pytest.mark.parametrize(
"attn_mask_type",
"qkv_layout, attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
pytest.param(
QKVLayout.THD_THD_THD,
AttnMaskType.PADDING_CAUSAL_MASK,
id="THD_SEPARATE-PADDING_CAUSAL",
),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
[pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
)
class TestDistributedContextParallelSelfAttn:
......@@ -265,7 +266,6 @@ class TestDistributedContextParallelSelfAttn:
data_shape = batch, seqlen, num_head, hidden
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)
runner = FusedAttnRunner(
batch,
......@@ -282,7 +282,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
SeqDescFormat.SegmentIDs,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -297,7 +297,7 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
mask_type,
dropout_prob,
num_head,
num_kv_heads,
......@@ -340,6 +340,8 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
load_balanced,
):
if qkv_layout.is_thd():
pytest.skip("THD doesn't support all gather context parallelism.")
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
......@@ -377,7 +379,10 @@ class TestDistributedContextParallelSelfAttn:
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
self.impl_test_context_parallel_attn(
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -404,17 +409,26 @@ class TestReorderCausalLoadBalancing:
],
)
@pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD])
def test(self, cp_size, shape, qkv_format):
@pytest.mark.parametrize(
"reorder_strategy",
[
pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"),
pytest.param(ReorderStrategy.Striped, id="Striped"),
],
)
def test(self, cp_size, shape, qkv_format, reorder_strategy):
tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16)
seq_dim = 1
if qkv_format == QKVFormat.SBHD:
tensor = tensor.swapaxes(0, 1)
seq_dim = 0
ref = tensor.copy()
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2])
reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3])
inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3])
reordered = reorder(tensor, cp_size, qkv_format)
inversed = inverse(reordered, cp_size, qkv_format)
reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim)
inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim)
assert jnp.array_equal(inversed, ref)
......@@ -28,12 +28,14 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
......@@ -347,9 +349,9 @@ class FusedAttnRunner:
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
......@@ -500,7 +502,8 @@ class FusedAttnRunner:
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
......@@ -536,6 +539,30 @@ class FusedAttnRunner:
self.window_size,
)
if self.cp_size > 1 and self.cp_load_balanced:
if self.qkv_layout.is_thd():
reorder_strategy = ReorderStrategy.Striped
else:
reorder_strategy = ReorderStrategy.DualChunkSwap
seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
# Test different input formats
if self.qkv_layout.is_thd():
match self.seq_desc_format:
......@@ -548,8 +575,14 @@ class FusedAttnRunner:
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
......@@ -605,7 +638,12 @@ class FusedAttnRunner:
case _:
def to_dp_shardings(x):
if x.ndim == 1:
pspec = PartitionSpec(self.mesh_resource.dp_resource)
else:
pspec = PartitionSpec(
self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
)
return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
......@@ -637,24 +675,6 @@ class FusedAttnRunner:
self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)
# Softmax aux sharding
if self.cp_size > 1 and self.cp_load_balanced:
self.cp_reorder_fn = partial(
reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
cp_size=self.cp_size,
tensor_format=self.qkv_layout.get_qkv_format(),
)
else:
# no-ops for non cp or non load balanced
self.cp_reorder_fn = lambda x: x
self.cp_inverse_reorder_fn = lambda x: x
def test_forward(self):
"""
Test forward without JIT
......@@ -733,14 +753,23 @@ class FusedAttnRunner:
self._setup_inputs()
def grad_func(func, *args, **kwargs):
def grad_func(func, *args, cp_reverse_out=False, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if self.attn_mask_type.is_causal():
gradient_multiplier /= 10
# Keep only valid result for the gradient
if not cp_reverse_out:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
func(*args, **kwargs),
)
else:
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
self.pad_q[..., jnp.newaxis, jnp.newaxis],
0,
self.cp_inverse_reorder_fn(func(*args, **kwargs)),
)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
......@@ -787,7 +816,7 @@ class FusedAttnRunner:
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
),
arg_nums,
),
......
......@@ -135,6 +135,39 @@ class QKVLayout(Enum):
"""
return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
def to_qkvpacked(self):
    """
    Return the corresponding qkvpacked layout, useful when adjusting q, k, v layout.

    Raises:
        ValueError: if this layout's format has no qkvpacked counterpart.
    """
    qkv_format = self.get_qkv_format()
    # Map each supported base format to its fully-packed layout.
    packed = {
        QKVFormat.BSHD: QKVLayout.BS3HD,
        QKVFormat.THD: QKVLayout.T3HD,
    }.get(qkv_format)
    if packed is None:
        raise ValueError(f"Unsupported {qkv_format=}")
    return packed
def to_kvpacked(self):
    """
    Return the corresponding kvpacked layout, useful when adjusting q, k, v layout.

    Raises:
        ValueError: if this layout's format has no kvpacked counterpart.
    """
    qkv_format = self.get_qkv_format()
    # Map each supported base format to its kv-packed layout.
    packed = {
        QKVFormat.BSHD: QKVLayout.BSHD_BS2HD,
        QKVFormat.THD: QKVLayout.THD_T2HD,
    }.get(qkv_format)
    if packed is None:
        raise ValueError(f"Unsupported {qkv_format=}")
    return packed
def to_separate(self):
    """
    Return the corresponding separate (unpacked) layout, useful when adjusting q, k, v layout.

    Raises:
        ValueError: if this layout's format has no separate counterpart.
    """
    qkv_format = self.get_qkv_format()
    # Map each supported base format to its fully-separate layout.
    separate = {
        QKVFormat.BSHD: QKVLayout.BSHD_BSHD_BSHD,
        QKVFormat.THD: QKVLayout.THD_THD_THD,
    }.get(qkv_format)
    if separate is None:
        raise ValueError(f"Unsupported {qkv_format=}")
    return separate
class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
......@@ -149,6 +182,28 @@ class CPStrategy(Enum):
RING = 2
class ReorderStrategy(Enum):
    """
    Defines the token re-order strategy used for context-parallel load balancing
    with a causal mask.

    - DualChunkSwap: Splits each sequence into two chunks per GPU and mirror-swaps
      the chunks between GPUs. Currently used for non-THD load balancing. Requires
      the max_seqlens to be a multiple of 2 * cp_size.
      Example: consider 4 GPUs with seqlen=16.
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]
      - After reorder:  GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]

    - Striped: Distributes the tokens in a striped (interleaved) manner across the
      sequence. Currently used for THD load balancing.
      Example: consider 4 GPUs with seqlen=16.
      - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15]
      - After reorder:  GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15]
    """

    # Chunk-pair mirror swap between GPUs (non-THD layouts).
    DualChunkSwap = 0
    # Interleaved token striping across GPUs (THD layouts).
    Striped = 1
def make_swa_mask(
segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray,
......@@ -243,9 +298,9 @@ def is_fused_attn_kernel_available(
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_probability,
q_num_heads,
kv_num_heads,
......@@ -276,16 +331,24 @@ def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
return batch, q_max_seqlen, kv_max_seqlen
def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int):
"""Reorders a tensor for load balancing the compute of causal attention."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False)
raise ValueError(f"Unsupported {strategy=}")
def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat):
def inverse_reorder_causal_load_balancing(
tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int
):
"""Inverse operation of `reorder_causal_load_balancing`."""
seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.DualChunkSwap:
return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True)
raise ValueError(f"Unsupported {strategy=}")
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
......@@ -412,8 +475,6 @@ class SequenceDescriptor:
"""
Acquire the seqlens/offsets for cuDNN backend
"""
attn_mask_type = AttnMaskType(attn_mask_type)
qkv_layout = QKVLayout(qkv_layout)
q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos
assert q_segment_ids.shape == q_segment_pos.shape
......@@ -589,9 +650,9 @@ def _legacy_fused_attn(
Intra-sequence padding is not valid. The padded tokens can only on the right-most.
Otherwise the results will be wrong.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......@@ -608,16 +669,18 @@ def _legacy_fused_attn(
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
case QKVLayout.BS3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
case QKVLayout.BSHD_BS2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
case QKVLayout.BSHD_BSHD_BSHD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
# convert the mask to seqlens, mask doesn't support ragged offsets
if not attn_mask_type.is_padding():
......@@ -689,16 +752,18 @@ def fused_attn_thd(
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
case QKVLayout.T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_T2HD:
case QKVLayout.THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
case QKVLayout.THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
case _:
raise ValueError(f"Unknown {qkv_layout=}")
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
assert q_seq_lens.shape == (batch, q_max_seqlen)
......@@ -789,9 +854,9 @@ def _fused_attn_fwd_rule(
bias,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......@@ -845,9 +910,9 @@ def _fused_attn_bwd_rule(
output,
dz,
sequence_descriptor,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
......@@ -903,9 +968,9 @@ def fused_attn(
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
......
......@@ -229,6 +229,10 @@ static void FusedAttnForwardImpl(
if (is_ragged) {
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
// Memset to 0xF0 for filling large negative numbers
auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads;
cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream);
}
/* Output tensors */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment