Unverified Commit c2c3d540 authored by Reese Wang, committed by GitHub

[JAX] Support segment_ids/pos as FA inputs (#1406)



* POC for segment_ids/segment_pos
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Change segment_pos position
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use RemainingArgs to resolve parameter-count mismatches
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Test mask_descriptor for accommodating different mask representations
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix bugs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use descriptor in bwd
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Primitives only accept pure jnp arrays
Signed-off-by: Reese Wang <rewang@nvidia.com>

* segment_ids/pos support POC
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move seqlens/offsets generation to mask descriptor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename MaskDescriptor to SequenceDescriptor
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Generalize get_seqlens_and_offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Utilize sequence descriptor in FA bwd
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Migrate to new API
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstrings
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove small inputs and test different input formats
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix seed shardings
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Optimize sequence conversion overhead
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Optimize seq_offsets calculation
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix up
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix conflicts
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove redundant line
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
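In short, this PR replaces the separate mask/seqlens/offsets arguments to `fused_attn` with a single `SequenceDescriptor`, which can be built from an explicit mask, from seqlens/offsets, or from segment IDs/positions. A minimal sketch of the new call pattern, using names that appear in this diff (`fused_attn`, `SequenceDescriptor.from_segment_ids_and_pos`); the shapes, dtypes, and keyword-argument names below are illustrative assumptions, not the exact released API:

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor, fused_attn,
)

b, s, h, d = 2, 8, 4, 64
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)

# Two packed segments per row; id 0 marks padding (see the test comments below).
segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0, 0]] * b, jnp.int32)

# Build the descriptor from segment ids; positions are derived when None.
# Ragged/THD callers can instead use from_seqlens_and_offsets(...).
seq_desc = SequenceDescriptor.from_segment_ids_and_pos(
    (segment_ids, segment_ids), None,
)

out = fused_attn(
    (q, k, v),                     # separate q/k/v tensors
    None,                          # bias
    seq_desc,                      # replaces mask/seqlens/offsets arguments
    None,                          # dropout rng
    attn_bias_type=AttnBiasType.NO_BIAS,       # kwarg names assumed
    attn_mask_type=AttnMaskType.PADDING_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    scaling_factor=d ** -0.5,
    dropout_probability=0.0,
    is_training=False,
)
```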
parent 3d7ff1c6
@@ -2,31 +2,18 @@
#
# See LICENSE for license information.
import pytest
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import (
generate_configs,
generate_context_parallel_configs,
generate_collectives_count,
compare_ops,
)
from utils import (
make_causal_mask,
make_self_mask,
assert_allclose,
print_debug_tensor_stats,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
@@ -36,10 +23,11 @@ from transformer_engine.jax.attention import (
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
import pytest
from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
DTYPES = [jnp.float16, jnp.bfloat16]
DTYPES = [jnp.bfloat16]
class TestDistributedSelfAttn:
@@ -141,6 +129,7 @@ class TestDistributedSelfAttn:
QKVLayout.BS3HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -205,6 +194,7 @@ class TestDistributedCrossAttn:
QKVLayout.BSHD_BS2HD,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -293,6 +283,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout,
bias_shape,
None,
SeqDescFormat.Seqlens,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
@@ -28,12 +28,11 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
fused_attn,
fused_attn_thd,
make_swa_mask,
SequenceDescriptor,
CPStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
@@ -199,8 +198,8 @@ def get_seqlens_and_offsets(segment_ids):
).squeeze(-1)
offsets = _find_offsets(segment_ids)
offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
seqlens = jnp.where(seqlens, seqlens, -1)
return seqlens, offsets
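For context on the two changed lines above: `jnp.insert` with index `-1` places the new value *before* the last element, while indexing at `shape[-1]` appends it *after* the end, which is what the padded seqlens/offsets layout needs. A minimal sketch of the difference:

```python
import jax.numpy as jnp

x = jnp.array([[3, 2, 4]])

# Old form: index -1 inserts before the last column.
jnp.insert(x, -1, values=0, axis=-1)           # [[3, 2, 0, 4]]

# New form: index x.shape[-1] appends after the last column.
jnp.insert(x, x.shape[-1], values=0, axis=-1)  # [[3, 2, 4, 0]]
```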
@@ -239,11 +238,7 @@ def customcall_fused_dpa(
key,
value,
bias,
mask,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
sequence_descriptor,
dropout_rng,
**kwargs,
):
@@ -264,19 +259,9 @@
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
if not qkv_layout.is_thd():
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args,
bias,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
).astype(query.dtype)
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)
class BiasShape(Enum):
@@ -290,6 +275,12 @@ class BiasShape(Enum):
_11SS = "11SS"
class SeqDescFormat(Enum):
Mask = auto()
Seqlens = auto()
SegmentIDs = auto()
@dataclass
class FusedAttnRunner:
"""
@@ -309,7 +300,8 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
@@ -327,11 +319,14 @@
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1
else:
# +1 for testing runtime_segments < max_segments
return self.num_segments_per_seq + 1
return 1
def _check_configs(self):
# TODO(rewang): probably adds this in is_fused_attn_available
@@ -462,11 +457,11 @@
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
# [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
# Not include paddings
max_segment_size = sequence_length // num_segments
@@ -541,16 +536,47 @@
self.window_size,
)
# Test different input formats
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
match self.seq_desc_format:
case SeqDescFormat.Mask:
pytest.skip("THD doesn't support mask input")
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
(self.seqlens_q, self.seqlens_kv),
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
(self.segment_pos_q, self.segment_pos_kv),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
else:
self.mask_for_customcall = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.sequence_desciptor = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)
case SeqDescFormat.Seqlens:
self.sequence_desciptor = SequenceDescriptor.from_seqlens(
(
self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
),
)
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
@@ -565,10 +591,21 @@
)
self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)
self.mask_pspec = PartitionSpec(
mask_pspec = PartitionSpec(
self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
)
self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
self.mask_sharding = NamedSharding(self.mesh, mask_pspec)
match self.seq_desc_format:
case SeqDescFormat.Mask:
self.seq_desc_sharding = self.mask_sharding
case _:
def to_dp_shardings(x):
pspec = PartitionSpec(self.mesh_resource.dp_resource)
return NamedSharding(self.mesh, pspec)
self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
if self.bias_shape == BiasShape._1HSS:
self.bias_pspec = PartitionSpec(
@@ -631,11 +668,7 @@
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
@@ -659,11 +692,7 @@
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
)
@@ -722,11 +751,7 @@
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.mask_for_customcall, self.mask_sharding),
jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
@@ -768,11 +793,7 @@
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
@@ -883,10 +904,7 @@
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
[
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param(
2,
2048,
@@ -897,8 +915,8 @@
jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
],
)
@pytest.mark.parametrize(
@@ -915,6 +933,14 @@
pytest.param(True, id="SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Mask, id="Mask"),
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
class TestFusedAttn:
"""
Fused attention tester
@@ -953,6 +979,7 @@
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
......@@ -977,6 +1004,7 @@ class TestFusedAttn:
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_forward()
@@ -1002,6 +1030,7 @@
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
@@ -1024,5 +1053,6 @@
qkv_layout,
bias_shape,
window_size,
seq_desc_format,
)
runner.test_backward()
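Since a `SequenceDescriptor` is a pytree of arrays, the distributed tests above shard it leaf-by-leaf over the data-parallel mesh axis with `jax.tree.map` (see `to_dp_shardings` in the diff). A standalone sketch of that pattern, assuming a dict stand-in for the descriptor and an illustrative mesh axis name:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over all available devices; the axis name "dp" is illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Stand-in pytree for a SequenceDescriptor: per-batch int32 arrays.
seq_desc = {
    "seqlens_q": jnp.ones((8,), jnp.int32),
    "seqlens_kv": jnp.ones((8,), jnp.int32),
}

def to_dp_shardings(x):
    # Shard each leaf's leading (batch) dimension over the data-parallel axis.
    return NamedSharding(mesh, PartitionSpec("dp"))

# One sharding per leaf, then place the whole tree in a single device_put.
seq_desc_sharding = jax.tree.map(to_dp_shardings, seq_desc)
seq_desc = jax.device_put(seq_desc, seq_desc_sharding)
```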
@@ -213,14 +213,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static void FusedAttnForwardImpl(
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output,
void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups,
size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens,
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
@@ -303,11 +303,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *k = buffers[1];
void *v = buffers[2];
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];
void *seed = buffers[4];
void *q_cu_seqlens = buffers[5];
void *kv_cu_seqlens = buffers[6];
void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;
/* Output buffer from XLA */
void *output = buffers[9];
@@ -316,7 +316,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
void *workspace = buffers[12];
FusedAttnForwardImpl(
stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed,
stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets,
output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch,
descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads,
descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
@@ -354,24 +354,24 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
Buffer_Type v_buf, Buffer_Type bias_buf,
Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Buffer_Type seed_buf, Result_Type output_buf,
Variadic_Buffer_Type _unused_args, Result_Type output_buf,
Result_Type softmax_aux_buf, Result_Type rng_state_buf,
Result_Type workspace_buf, Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
FusedAttnForwardImpl(
stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(),
bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(),
is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(),
output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(),
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(),
kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr,
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
return ffi_with_cuda_error_check();
}
@@ -383,11 +383,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.Arg<Buffer_Type>() // k
.Arg<Buffer_Type>() // v
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // seed_buf
.Arg<Buffer_Type>() // q_cu_seqlens
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.Arg<Buffer_Type>() // seed_buf
.RemainingArgs() // _cp_aux_args unused
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // softmax_aux
.Ret<Buffer_Type>() // rng_state
@@ -642,9 +643,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
Buffer_Type output_buf, Buffer_Type doutput_buf,
Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf,
Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf,
Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
Dictionary attrs) {
Variadic_Buffer_Type _unused_args, Result_Type dq_buf,
Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf,
Result_Type workspace_buf, Dictionary attrs) {
FUSED_ATTN_FFI_GET_ATTRS;
FusedAttnBackwardImpl(
@@ -677,6 +678,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.Arg<Buffer_Type>() // kv_cu_seqlens
.Arg<Buffer_Type>() // q_seq_offsets
.Arg<Buffer_Type>() // k_seq_offsets
.RemainingArgs() // _cp_aux_args unused
.Ret<Buffer_Type>() // dq
.Ret<Buffer_Type>() // dk
.Ret<Buffer_Type>() // dv
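For readers tracking the ABI change in this file: the forward implementation now takes `seed` ahead of the seqlen/offset buffers, and both FFI handlers gain `.RemainingArgs()` so a variable tail of operands (the unused `_cp_aux_args`) no longer trips a parameter-count check. A plain listing of the resulting forward operand order for reference; the list name is illustrative, and the indices mirror the `buffers[]` accesses above:

```python
# Forward FFI operand order after this change (illustrative summary).
FORWARD_OPERANDS = [
    "q",              # buffers[0]
    "k",              # buffers[1]
    "v",              # buffers[2]
    "bias",           # buffers[3]
    "seed",           # buffers[4]  (moved ahead of the seqlen buffers)
    "q_cu_seqlens",   # buffers[5]
    "kv_cu_seqlens",  # buffers[6]
    "q_seq_offsets",  # buffers[7]  (ragged layouts only, else nullptr)
    "k_seq_offsets",  # buffers[8]  (ragged layouts only, else nullptr)
    # buffers[9..12] are results: output, softmax_aux, rng_state, workspace;
    # any further trailing operands are absorbed by .RemainingArgs().
]
```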