Unverified Commit c2c3d540 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Support segment_ids/pos as FA inputs (#1406)



* POC for segment_ids/segment_pos
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Change segment_pos position
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use RemainingArgs to solve number of parameters mismatches
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Test mask_descriptor for accommodating different mask representations
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix bugs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Use descriptor in bwd
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Primitives only accepts pure jnp array
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* segment_ids/pos support POC
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Move seqlens/offsets generation to mask descriptor
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Rename MaskDescriptor to SequenceDescriptor
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Generalize get_seqlens_and_offsets
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Utilize sequence desc on FA bwd
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Migrate to new API
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add docstrings
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove small inputs and test different input format
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix seed shardings
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Optimize sequence converting overhead
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Optimize seq_offsets calculation
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix up
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* fix lint
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix conflicts
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Remove redundant line
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 3d7ff1c6
......@@ -2,31 +2,18 @@
#
# See LICENSE for license information.
import pytest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import (
make_causal_mask,
make_self_mask,
assert_allclose,
print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
......@@ -36,10 +23,11 @@ from transformer_engine.jax.attention import (
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
import pytest
from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
DTYPES = [jnp.float16, jnp.bfloat16]
DTYPES = [jnp.bfloat16]
class TestDistributedSelfAttn:
......@@ -141,6 +129,7 @@ class TestDistributedSelfAttn:
QKVLayout.BS3HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -205,6 +194,7 @@ class TestDistributedCrossAttn:
QKVLayout.BSHD_BS2HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......@@ -293,6 +283,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
......@@ -28,12 +28,11 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
fused_attn_thd,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
......@@ -199,8 +198,8 @@ def get_seqlens_and_offsets(segment_ids):
).squeeze(-1)
offsets = _find_offsets(segment_ids)
offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
seqlens = jnp.where(seqlens, seqlens, -1)
return seqlens, offsets
......@@ -239,11 +238,7 @@ def customcall_fused_dpa(
key,
value,
bias,
mask,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
sequence_descriptor,
dropout_rng,
**kwargs,
):
......@@ -264,19 +259,9 @@ def customcall_fused_dpa(
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
if not qkv_layout.is_thd():
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args,
bias,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
).astype(query.dtype)
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)
class BiasShape(Enum):
......@@ -290,6 +275,12 @@ class BiasShape(Enum):
_11SS = "11SS"
class SeqDescFormat(Enum):
    """Ways the test can describe sequences to fused attention.

    Mask: a dense attention mask built by make_mask (non-THD only; the
        THD path skips this format).
    Seqlens: explicit sequence lengths (plus ragged offsets for THD),
        via SequenceDescriptor.from_seqlens / from_seqlens_and_offsets.
    SegmentIDs: per-token segment ids (and optional positions), via
        SequenceDescriptor.from_segment_ids_and_pos.
    """

    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()
@dataclass
class FusedAttnRunner:
"""
......@@ -309,7 +300,8 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
......@@ -327,11 +319,14 @@ class FusedAttnRunner:
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
    """Return the max_segments_per_seq value to use for this test config.

    Only THD (packed) layouts can carry multiple segments per sequence;
    every other layout always has exactly one segment per sequence.
    """
    if self.qkv_layout.is_thd():
        # cuDNN 9.4.x has a known issue generating zero-length ragged
        # tensors (see the cuDNN 9.4.0 release notes), so on those versions
        # don't request more segments than the test actually creates.
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        else:
            # +1 for testing runtime_segments < max_segments
            return self.num_segments_per_seq + 1
    else:
        return 1
def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
......@@ -462,11 +457,11 @@ class FusedAttnRunner:
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
# [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
# Not include paddings
max_segment_size = sequence_length // num_segments
......@@ -541,16 +536,47 @@ class FusedAttnRunner:
self.window_size,
)
# Test different input formats
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
match self.seq_desc_format:
case SeqDescFormat.Mask:
pytest.skip("THD doesn't support mask input")
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
(self.seqlens_q, self.seqlens_kv),
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
else:
self.mask_for_customcall = make_mask(
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
......@@ -565,10 +591,21 @@ class FusedAttnRunner:
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
self.mask_pspec = PartitionSpec(
mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
self.mask_sharding = NamedSharding(self.mesh, mask_pspec)
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.seq_desc_sharding = self.mask_sharding
case _:
def to_dp_shardings(x):
pspec = PartitionSpec(self.mesh_resource.dp_resource)
return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
......@@ -631,11 +668,7 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
......@@ -659,11 +692,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
)
......@@ -722,11 +751,7 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
......@@ -768,11 +793,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
......@@ -883,10 +904,7 @@ class FusedAttnRunner:
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
[
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param(
2,
2048,
......@@ -897,8 +915,8 @@ class FusedAttnRunner:
jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
],
)
@pytest.mark.parametrize(
......@@ -915,6 +933,14 @@ class FusedAttnRunner:
pytest.param(True, id="SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Mask, id="Mask"),
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
class TestFusedAttn:
"""
Fused attention tester
......@@ -953,6 +979,7 @@ class TestFusedAttn:
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
......@@ -977,6 +1004,7 @@ class TestFusedAttn:
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_forward()
......@@ -1002,6 +1030,7 @@ class TestFusedAttn:
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
......@@ -1024,5 +1053,6 @@ class TestFusedAttn:
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_backward()
......@@ -2,13 +2,16 @@
#
# See LICENSE for license information.
"""JAX multi-head attention modules"""
from __future__ import annotations
from enum import Enum
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import warnings
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from flax.linen import make_attention_mask
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
......@@ -252,27 +255,23 @@ def is_fused_attn_kernel_available(
(-1, -1) if window_size is None else window_size,
)
if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False
return True
return make_helper(attn_mask_type).is_fused_attn_kernel_available()
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
match qkv_layout:
case QKVLayout.BS3HD | QKVLayout.T3HD:
if qkv_layout.is_qkvpacked():
assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = q_max_seqlen
case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
elif qkv_layout.is_kvpacked():
assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = qkv[1].shape[1]
case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
elif qkv_layout.is_separate():
assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = qkv[1].shape[1]
case _:
else:
raise ValueError(f"Unsupported {qkv_layout=}")
return batch, q_max_seqlen, kv_max_seqlen
......@@ -289,7 +288,273 @@ def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: Q
return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
def fused_attn(
def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq):
# bincount map with 0s
bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1))
seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
seqlens = seqlens_with_zero[..., 1:]
def _find_offsets(x):
same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
first_column = x[..., :1] != 0
same_as_previous = jnp.hstack((first_column, same_as_previous))
return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
same_as_previous
).squeeze(-1)
offsets = _find_offsets(segment_ids)
return seqlens, offsets
def _mask_to_seqlens_offset(mask, max_segments_per_seq):
    """Reduce a segment-id-tagged mask to seqlens and ragged offsets.

    Args:
        mask: integer array of shape [batch, 1, q_seqlen, kv_seqlen] where
            each attended entry holds its segment id and masked entries are 0.
        max_segments_per_seq: static upper bound on segments per sequence.

    Returns:
        (q_seqlen, q_offset, kv_seqlen, kv_offset) as produced by
        _get_seqlens_and_offsets for the row-wise and column-wise id maxima.
    """
    assert mask.shape[1] == 1
    squeezed = mask.squeeze(axis=1)
    # Taking the max over one axis recovers the segment id owning each
    # position along the other axis (0 where nothing is attended).
    q_seqlen, q_offset = _get_seqlens_and_offsets(squeezed.max(axis=-1), max_segments_per_seq)
    kv_seqlen, kv_offset = _get_seqlens_and_offsets(squeezed.max(axis=-2), max_segments_per_seq)
    return q_seqlen, q_offset, kv_seqlen, kv_offset
def _segment_ids_pos_to_seqlens_offsets(
    segment_ids_q,
    segment_ids_kv,
    segment_pos_q,
    segment_pos_kv,
    attn_mask_type,
    window_size,
    max_segments_per_seq,
):
    """Convert segment ids/positions into cuDNN seqlens and ragged offsets (THD).

    Builds the effective attention mask (segment match, optionally causal,
    plus sliding window), tags each attended entry with its segment id, then
    reduces that tagged mask to per-segment lengths and start offsets.

    Returns:
        (q_seqlen, kv_seqlen, q_offset, kv_offset)
    """
    # (1 = attend, 0 = masked)
    segment_mask = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        jnp.equal,
    )
    # Same mask shape, but each attended entry carries the q segment id
    # instead of a boolean, so segment identity survives the reduction below.
    segment_mask_with_id = make_attention_mask(
        segment_ids_q,
        segment_ids_kv,
        lambda x, y: jnp.equal(x, y) * x,
    )
    attn_mask = segment_mask
    if attn_mask_type.is_causal():
        # Causality is enforced on intra-segment positions, not absolute
        # token indices, so packed segments stay independent.
        causal_mask = make_attention_mask(
            segment_pos_q,
            segment_pos_kv,
            jnp.greater_equal,
        )
        attn_mask = jnp.logical_and(segment_mask, causal_mask)
    swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
    attn_mask = jnp.logical_and(attn_mask, swa_mask)
    attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
    # NOTE: _mask_to_seqlens_offset returns (q_seqlen, q_offset, kv_seqlen,
    # kv_offset); this function deliberately reorders the tuple to
    # (q_seqlen, kv_seqlen, q_offset, kv_offset).
    q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
        attn_mask_with_id, max_segments_per_seq
    )
    return q_seqlen, kv_seqlen, q_offset, kv_offset
def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
# convert the mask to seqlens, mask doesn't support ragged offsets
if not attn_mask_type.is_padding():
q_max_seqlen = segment_ids_q.shape[-1]
kv_max_seqlen = segment_ids_kv.shape[-1]
q_seq_lens = jnp.full_like(q_max_seqlen, q_max_seqlen, dtype=jnp.int32)
kv_seq_lens = jnp.full_like(kv_max_seqlen, kv_max_seqlen, dtype=jnp.int32)
else:
q_seq_lens = jnp.sum(segment_ids_q, axis=-1).astype(jnp.int32)
kv_seq_lens = jnp.sum(segment_ids_kv, axis=-1).astype(jnp.int32)
return q_seq_lens, kv_seq_lens
@jax.tree_util.register_pytree_node_class
class SequenceDescriptor:
    """A class to describe the sequences with flexible initialization.

    - SequenceDescriptor.from_seqlens
        For non-THD (non-packed) cases, where each batch has only 1 sequence.
    - SequenceDescriptor.from_seqlens_and_offsets
        For THD (packed) cases, where each batch may have not only 1 sequence.
    - SequenceDescriptor.from_segment_ids_and_pos
        Experimental feature for THD (packed) cases with context parallelism.
    """

    # Each attribute is a (q, kv) pair. Unused representations hold empty
    # placeholder arrays (see __init__) instead of None so that instances
    # stay pytrees of pure jax arrays, as the primitives require.
    seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]]
    segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]]

    def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None):
        """
        Initialize to Tuple(jnp.zeros, jnp.zeros) because the primitive only accepts pure jax array
        """
        self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens
        self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets
        self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids
        self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos

    def tree_flatten(self):
        """
        Flatten method to register as a pytree node
        """
        return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Unflatten method to register as a pytree node
        """
        del aux_data
        return cls(*children)

    def get_seqlens_and_offsets(
        self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
    ):
        """
        Acquire the seqlens/offsets for cuDNN backend
        """
        attn_mask_type = AttnMaskType(attn_mask_type)
        qkv_layout = QKVLayout(qkv_layout)
        q_segment_ids, kv_segment_ids = self.segment_ids
        q_segment_pos, kv_segment_pos = self.segment_pos
        assert q_segment_ids.shape == q_segment_pos.shape
        assert kv_segment_ids.shape == kv_segment_pos.shape
        # No segment_ids/segment_pos: the caller supplied seqlens (and
        # possibly offsets) directly, so pass them through unchanged.
        if q_segment_ids.size + kv_segment_ids.size == 0:
            return self.seqlens, self.seq_offsets
        if qkv_layout.is_thd():
            q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets(
                q_segment_ids,
                kv_segment_ids,
                q_segment_pos,
                kv_segment_pos,
                attn_mask_type,
                window_size,
                max_segments_per_seq,
            )
        else:
            q_seqlens, kv_seqlens = _segment_ids_to_seqlens(
                q_segment_ids,
                kv_segment_ids,
                attn_mask_type,
            )
            # Non-THD layouts take no ragged offsets; use empty placeholders.
            q_offsets = kv_offsets = jnp.zeros(0)
        return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets)

    @classmethod
    def _expand_to_pair(
        cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Internal helper to ensure a single value expands into a pair (q, kv).
        """
        if isinstance(value, tuple):
            if len(value) != 2:
                raise ValueError("Input tuple must have exactly 2 elements.")
            return value
        if isinstance(value, jnp.ndarray):
            return value, value  # Duplicate for q=kv case
        raise TypeError(
            "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, "
            f"but got {type(value).__name__}."
        )

    @classmethod
    def from_seqlens(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths only (non-THD).

        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                    Sequence lengths for the query, with shape [batch].
                - kv_seqlens (jnp.ndarray):
                    Sequence lengths for the key and value, with shape [batch].
        Return:
            A SequenceDescriptor with only seqlens initialized.
        """
        q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens)
        return cls(seqlens=(q_seqlens, kv_seqlens))

    @classmethod
    def from_seqlens_and_offsets(
        cls,
        seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
    ) -> SequenceDescriptor:
        """
        Factory method for inputs with sequence lengths and offsets (THD).

        Args:
            seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens):
                - q_seqlens (jnp.ndarray):
                    Sequence lengths for the query, with shape [batch, max_seqlen].
                    Unused positions are padded with -1.
                - kv_seqlens (jnp.ndarray):
                    Sequence lengths for the key and value, with shape [batch, max_seqlen].
                    Unused positions are padded with -1.
            seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets)
                - q_seq_offsets (jnp.ndarray):
                    The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
                    Unused positions are padded with -1.
                - kv_seq_offsets (jnp.ndarray):
                    The offsets in the sequence dim for the key and value, with shape
                    [batch, max_seqlen + 1]. Unused positions are padded with -1.
        Return:
            A SequenceDescriptor with seqlens/seq_offsets initialized.
        """
        q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens)
        q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets)
        return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets))

    @classmethod
    def from_segment_ids_and_pos(
        cls,
        segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
        segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
    ) -> SequenceDescriptor:
        """
        Experimental factory method for inputs with segment IDs and optional positions. (THD)

        Args:
            segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
                - q_segment_ids (jnp.ndarray):
                    Query segment ids start with 1, with shape [batch, max_seqlen].
                    0s are treated as paddings.
                - kv_segment_ids (jnp.ndarray):
                    Key, value segment ids start with 1, with shape [batch, max_seqlen].
                    0s are treated as paddings.
            segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos)
                - q_segment_pos (jnp.ndarray):
                    The position inside each segment for query, with shape [batch, max_seqlen].
                - kv_segment_pos (jnp.ndarray):
                    The position inside each segment for key, value, with shape [batch, max_seqlen].
        Return:
            A SequenceDescriptor with segment_ids/segment_pos initialized.
        """
        q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)
        if segment_pos is not None:
            segment_pos = cls._expand_to_pair(segment_pos)
        else:
            # Default positions assume no packing: token index within the row.
            def generate_default_pos(segment_ids):
                seqlen = segment_ids.shape[-1]
                return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)

            q_seg_pos = generate_default_pos(q_seg_ids)
            kv_seg_pos = generate_default_pos(kv_seg_ids)
            segment_pos = (q_seg_pos, kv_seg_pos)
        return cls(
            segment_ids=(q_seg_ids, kv_seg_ids),
            segment_pos=segment_pos,
        )
def _legacy_fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
mask: Optional[jnp.ndarray],
......@@ -372,10 +637,7 @@ def fused_attn(
output = _fused_attn(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
None,
None,
SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -414,63 +676,13 @@ def fused_attn_thd(
context_parallel_axis: str = "",
):
"""
(Experimental) Perform THD (packed) cuDNN fused attention.
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
q_seqlen (jnp.ndarray):
Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
padded with -1.
kv_seqlen (jnp.ndarray):
Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions
are padded with -1.
q_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
Unused positions are padded with -1.
kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
Unused positions are padded with -1.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
Indicating the maximum number of segments inside a sequence. This parameter is to
constrain the limit usage and need to be static during the e2e training. The XLA compile
time and memory consumption is proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]):
Sliding window size.
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
Examples:
>>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
>>> b, s, h, d = 2, 4, 12, 64
>>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
>>> # 3 segments in first seq, 2 segments in second seq
>>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
>>> # seq_offsets need to include the end offset of the last segments
>>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
>>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
q_seq_offsets, kv_seq_offsets, None,
AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
QKVLayout.T3HD, 0.125, 0, True, 3)
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
"""
warnings.warn(
"fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor",
DeprecationWarning,
)
assert (
qkv_layout.is_thd()
), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."
......@@ -497,10 +709,9 @@ def fused_attn_thd(
output = _fused_attn(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
SequenceDescriptor.from_seqlens_and_offsets(
(q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
),
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -518,15 +729,12 @@ def fused_attn_thd(
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
q_seq_lens: jnp.ndarray,
kv_seq_lens: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
seed: jnp.ndarray,
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
......@@ -542,10 +750,7 @@ def _fused_attn(
output, _ = _fused_attn_fwd_rule(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
......@@ -565,10 +770,7 @@ def _fused_attn(
def _fused_attn_fwd_rule(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
seed,
attn_bias_type,
attn_mask_type,
......@@ -585,10 +787,7 @@ def _fused_attn_fwd_rule(
output, softmax_aux, rng_state = tex.fused_attn_fwd(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
......@@ -608,10 +807,7 @@ def _fused_attn_fwd_rule(
return output, (
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
softmax_aux,
rng_state,
output,
......@@ -636,10 +832,7 @@ def _fused_attn_bwd_rule(
(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
softmax_aux,
rng_state,
output,
......@@ -651,10 +844,7 @@ def _fused_attn_bwd_rule(
rng_state,
output,
dz,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
sequence_descriptor,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
......@@ -669,7 +859,137 @@ def _fused_attn_bwd_rule(
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None, None, None, None
return (
grad_qkv,
grad_bias,
None,
None,
)
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
):
    """
    Perform cuDNN fused attention.

    This function implements the following formula:
    BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    Args:
        qkv (Tuple[jnp.ndarray, ...]): Query/key/value tensors in one of three formats:
            - ``(qkv_packed,)``: interleaved QKV packed format, typically used when query,
              key, and value have the same shape (e.g., self-attention).
            - ``(query, kv_packed)``: separate query and packed KV, typically used when
              query has a different shape (e.g., cross-attention).
            - ``(query, key, value)``: fully separate query, key, and value tensors.
        bias (Optional[jnp.ndarray]): Optional bias tensor added to the attention scores.
        sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the
            sequence. Passing a raw ``jnp.ndarray`` mask here is deprecated and is routed
            to the legacy implementation.
        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
        attn_bias_type (AttnBiasType): Type of attention bias.
        attn_mask_type (AttnMaskType): Type of attention mask.
        qkv_layout (QKVLayout): Layout of the QKV tensors.
        scaling_factor (float): Scaling factor for the attention scores.
        dropout_probability (float): Dropout probability to apply during attention.
        is_training (bool): Flag indicating whether the model is in training mode.
        max_segments_per_seq (int): Maximum number of segments inside a sequence. This
            constrains the limit usage and needs to be static during e2e training; XLA
            compile time and memory consumption are proportional to this value.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
        context_parallel_strategy (CPStrategy): Strategy used to run context parallelism.
        context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered
            for causal mask load balancing when running context parallelism.
        context_parallel_axis (str): The name of the context parallel axis.

    Returns:
        (jnp.ndarray): The output tensor from the fused attention.

    Examples (non-THD, also known as non-packed):
        >>> # q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> q_seq_lens = jnp.asarray([3, 2])
        >>> kv_seq_lens = jnp.asarray([1, 2])
        >>> sequence_desc = SequenceDescriptor.from_seqlens(
        ...     seqlens=(q_seq_lens, kv_seq_lens))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
        ...                  AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
        ...                  QKVLayout.BS3HD, 0.125, 0, True, 3)

    Examples (THD, also known as packed):
        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
        >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]]
        >>> b, s, h, d = 2, 4, 12, 64
        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
        >>> # 3 segments in first seq, 2 segments in second seq
        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
        >>> # seq_offsets need to include the end offset of the last segments
        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
        >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets(
        ...     seqlens=(q_seq_lens, kv_seq_lens),
        ...     seq_offsets=(q_seq_offsets, kv_seq_offsets))
        >>> out = fused_attn((qkv,), None, sequence_desc, None,
        ...                  AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
        ...                  QKVLayout.T3HD, 0.125, 0, True, 3)
    """
    # Keyword arguments shared by both the legacy and the current implementation.
    shared_kwargs = dict(
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=qkv_layout,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training,
        window_size=window_size,
        context_parallel_strategy=context_parallel_strategy,
        context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
        context_parallel_axis=context_parallel_axis,
    )

    # Deprecated calling convention: a raw mask array instead of a SequenceDescriptor.
    if isinstance(sequence_descriptor, jnp.ndarray):
        warnings.warn(
            "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. "
            + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.",
            DeprecationWarning,
        )
        # The legacy mask path has no THD support, so multiple segments are rejected.
        if max_segments_per_seq != 1:
            raise ValueError("Passing mask is only supported for non-THD case.")
        return _legacy_fused_attn(
            qkv,
            bias,
            sequence_descriptor,
            seed,
            **shared_kwargs,
        )

    # Current path: dispatch to the custom-VJP implementation, which additionally
    # takes the static max_segments_per_seq bound.
    return _fused_attn(
        qkv,
        bias,
        sequence_descriptor,
        seed,
        max_segments_per_seq=max_segments_per_seq,
        **shared_kwargs,
    )
......@@ -17,7 +17,7 @@ from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine.jax.attention import CPStrategy
from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
......@@ -211,9 +211,8 @@ def generate_cu_seqlen(actual_seqlen):
"""
Generating cumsum seqlen for a batch
"""
cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1)
cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen)
cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1)
actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen)
cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True)
return cu_seqlen
......@@ -224,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (9,)
impl_static_args = (13,)
inner_primitive = None
outer_primitive = None
......@@ -234,11 +233,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval,
v_aval,
bias_aval,
seed_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
seed_aval,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config: _FusedAttnConfig,
):
......@@ -354,11 +357,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config: _FusedAttnConfig,
):
......@@ -387,11 +394,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
......@@ -417,11 +428,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
......@@ -466,15 +477,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config: _FusedAttnConfig,
):
assert FusedAttnFwdPrimitive.inner_primitive is not None
sequence_descriptor = SequenceDescriptor(
seqlens=(q_seqlen, kv_seqlen),
seq_offsets=(q_seq_offsets, k_seq_offsets),
segment_ids=(_q_segment_ids, _kv_segment_ids),
segment_pos=(_q_segment_pos, _kv_segment_pos),
)
(q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
sequence_descriptor.get_seqlens_and_offsets(
config.attn_mask_type,
config.qkv_layout,
config.window_size,
config.max_segments_per_seq,
)
)
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition, fill_value=-1):
......@@ -517,6 +548,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
......@@ -524,15 +556,17 @@ class FusedAttnFwdPrimitive(BasePrimitive):
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather valid q_seq_offsets, which is greater and equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused position to max size (batch * max_seqlen)
# And set the unused position to max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_seq_offsets = _fix_len_take(
q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
)
k_seq_offsets = _fix_len_take(
k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
......@@ -542,11 +576,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
return output, softmax_aux, rng_state
......@@ -555,7 +593,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, *_, seed_bdim = batch_dims
q_bdim, _, _, _, seed_bdim, *_ = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return (
......@@ -600,7 +638,9 @@ class FusedAttnFwdPrimitive(BasePrimitive):
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
......@@ -616,7 +656,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (12,)
impl_static_args = (16,)
inner_primitive = None
outer_primitive = None
......@@ -634,6 +674,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config,
):
......@@ -718,6 +762,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
*,
config,
):
......@@ -754,6 +802,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
......@@ -839,10 +891,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config,
):
assert FusedAttnBwdPrimitive.inner_primitive is not None
sequence_descriptor = SequenceDescriptor(
seqlens=(q_seqlen, kv_seqlen),
seq_offsets=(q_seq_offsets, k_seq_offsets),
segment_ids=(_q_segment_ids, _kv_segment_ids),
segment_pos=(_q_segment_pos, _kv_segment_pos),
)
(q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
sequence_descriptor.get_seqlens_and_offsets(
config.attn_mask_type,
config.qkv_layout,
config.window_size,
config.max_segments_per_seq,
)
)
if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition, fill_value=-1):
......@@ -893,15 +965,17 @@ class FusedAttnBwdPrimitive(BasePrimitive):
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather valid q_seq_offsets, which is greater and equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused position to max size (batch * max_seqlen)
# And set the unused position to max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_seq_offsets = _fix_len_take(
q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
)
k_seq_offsets = _fix_len_take(
k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
......@@ -919,6 +993,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
return dq, dk, dv, dbias
......@@ -975,6 +1053,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
......@@ -989,6 +1071,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
global_dbias = local_dbias
......@@ -1240,10 +1326,26 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed):
def impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
......@@ -1280,11 +1382,15 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_unmasked,
v_unmasked,
bias,
seed,
q_seqlen_for_step,
kv_seqlen_for_step,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(),
)
results.append((output, softmax_aux, rng_state))
......@@ -1357,13 +1463,31 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# See comment in FusedAttnCPFwdPrimitive.partition for why we define this function.
def _cross_attn_bwd(
idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen
idx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
......@@ -1402,6 +1526,10 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_for_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(),
)
......@@ -1433,6 +1561,10 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
doutput,
q_seqlen,
kv_seqlen,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
)
for idx in range(cp_size)
]
......@@ -1595,7 +1727,9 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def ring_attn_fwd_impl(
......@@ -1603,11 +1737,15 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
_not_used = jnp.zeros(0, dtype=v.dtype)
......@@ -1644,12 +1782,16 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
helper.get_step_config(attn_mask_type),
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(attn_mask_type),
)
return output_per_step, softmax_aux_per_step
......@@ -1665,11 +1807,15 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
return output_per_step, softmax_aux_per_step
......@@ -1683,11 +1829,15 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
seed,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
......@@ -1805,6 +1955,10 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
_not_used = jnp.zeros(0, dtype=output.dtype)
......@@ -1849,6 +2003,10 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(attn_mask_type),
)
return dq_per_step, dk_dv_per_step, dbias_per_step
......@@ -1873,6 +2031,10 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
dk_dv_per_step = jnp.concat(
......@@ -1907,6 +2069,10 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
)
dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
......@@ -1975,10 +2141,7 @@ def _maybe_context_parallel_axis(cp_axis: str):
def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
......@@ -2031,14 +2194,9 @@ def fused_attn_fwd(
(jnp.ndarray): The output tensor from the fused attention.
"""
seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
assert (q_seq_offsets is None) == (
kv_seq_offsets is None
), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
# For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
......@@ -2071,21 +2229,19 @@ def fused_attn_fwd(
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
primative = None
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING:
primative = FusedRingAttnFwdPrimitive.outer_primitive
primitive = FusedRingAttnFwdPrimitive.outer_primitive
return primative.bind(
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind(
*qkv_for_primitive,
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
seed,
*seq_desc_flatten,
config=fused_config,
)
......@@ -2097,10 +2253,7 @@ def fused_attn_bwd(
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
......@@ -2155,12 +2308,6 @@ def fused_attn_bwd(
same format as the input `qkv`.
- The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
"""
assert (q_seq_offsets is None) == (
kv_seq_offsets is None
), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
# For optional tensors, which custom calls doesn't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
......@@ -2196,24 +2343,23 @@ def fused_attn_bwd(
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
)
primative = None
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING:
primative = FusedRingAttnBwdPrimitive.outer_primitive
primitive = FusedRingAttnBwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
*qkv_grads, bias_grad = primative.bind(
*qkv_grads, bias_grad = primitive.bind(
*qkv_for_primitive,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
*seq_desc_flatten,
config=fused_config,
)
return tuple(qkv_grads[: len(qkv)]), bias_grad
......@@ -213,14 +213,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static void FusedAttnForwardImpl(
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
......@@ -303,11 +303,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];
void *seed = buffers[4];
void *q_cu_seqlens = buffers[5];
void *kv_cu_seqlens = buffers[6];
void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;
/* Output buffer from XLA */
void *output = buffers[9];
......@@ -316,7 +316,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *workspace = buffers[12];
FusedAttnForwardImpl(
stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed,
stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets,
output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch,
descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads,
descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
......@@ -354,24 +354,24 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Buffer_Type seed_buf, Result_Type output_buf,
Variadic_Buffer_Type _unused_args, Result_Type output_buf,
Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
return ffi_with_cuda_error_check();
}
......@@ -383,11 +383,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.Arg<Buffer_Type>() // k
.Arg<Buffer_Type>() // v
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // seed_buf
.Arg<Buffer_Type>() // q_cu_seqlens
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.Arg<Buffer_Type>() // seed_buf
.RemainingArgs() // _cp_aux_args unused
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state
......@@ -642,9 +643,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
Buffer_Type output_buf, Buffer_Type doutput_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
Dictionary attrs) {
Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
Result_Type workspace_buf, Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
FusedAttnBackwardImpl(
......@@ -677,6 +678,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.RemainingArgs() // _cp_aux_args unused
.Ret<Buffer_Type>() // dq
.Ret<Buffer_Type>() // dk
.Ret<Buffer_Type>() // dv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment