Unverified Commit 1ddfa0c6 authored by Kshitij Lakhani, committed by GitHub

[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851)



* Add support for Fused Attn MLA head_dim_qk != head_dim_v
	Modify is_fused_attn_kernel_available() to accept different head dims for qk and v
	Modify FusedAttnHelper to accept different head dims for qk and v, and update the dim-check asserts in parse_qkv_aval()
	Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head dims for qk and v
	Modify the Fused Attn related cpp and csrc extension API calls to accept different head dims for qk and v
	Modify DotProductAttention call() to extract the head dims separately for qk and v
	Modify the FusedAttn tests to accommodate the FusedAttn API changes
	Add a test case for head_dim_qk != head_dim_v (initially failing)
	Modify the baseline JAX implementation to reshape the output tensor based on the v dims rather than the q dims (see the shape sketch below)
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
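
For orientation, the shape contract this PR enables, as a minimal pure-JAX sketch (illustrative only, with hypothetical sizes; not the TE kernel path): q and k share head_dim_qk, while v, and therefore the attention output, carry head_dim_v.

    import jax
    import jax.numpy as jnp

    b, s_q, s_kv, h, d_qk, d_v = 2, 128, 128, 16, 64, 32  # hypothetical MLA-style dims
    q = jnp.zeros((b, s_q, h, d_qk), jnp.bfloat16)
    k = jnp.zeros((b, s_kv, h, d_qk), jnp.bfloat16)
    v = jnp.zeros((b, s_kv, h, d_v), jnp.bfloat16)

    # Logits scale with the qk head dim; the output head dim comes from v.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(jnp.float32(d_qk)).astype(q.dtype)
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
    out = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
    assert out.shape == (b, s_q, h, d_v)  # output follows v's head dim, not q's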

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix context dims in general DPA in test_fused_attn
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix the output tensor dim by using the v head dim rather than the q head dim
Add test cases for jax fused attn where head_dim_qk != head_dim_v, covering a combination of data types and attention types
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Modify the fused attn jax unit test case for head dim qk != head dim v
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Use the new FusedAttnRunner signature with separate hidden dims for qk and v in the Fused Attn distributed tests
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix usage of the is_fused_attn_kernel_available signature in the distributed tests
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove unnecessary assert
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 71c76b6b
@@ -80,6 +80,7 @@ class TestDistributedSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -99,6 +100,7 @@ class TestDistributedSelfAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -227,6 +229,7 @@ class TestDistributedCrossAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -239,6 +242,7 @@ class TestDistributedCrossAttn:
             num_head,
             num_head,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -329,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
             num_head,
             num_kv_heads,
             hidden,
+            hidden,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -360,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,
         )  # no SWA for CP
......
@@ -106,7 +106,8 @@ def general_dot_product_attention(
     softmax_out = softmax_out * multiplier

     context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
-    context = jnp.reshape(context, query.shape)
+    context_shape = query.shape[:-1] + (value.shape[-1],)
+    context = jnp.reshape(context, context_shape)
     return context
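
The reshape fix above is easy to sanity-check: the einsum's trailing axis is value's head dim, so the flattened context shape must end in d_v rather than d_qk. A minimal sketch with hypothetical dims (h_kv KV heads, g query heads per KV head):

    import jax.numpy as jnp

    b, s_q, s_kv, h_kv, g, d_qk, d_v = 2, 8, 8, 2, 3, 16, 4  # hypothetical sizes
    softmax_out = jnp.zeros((b, h_kv, g, s_q, s_kv))
    value = jnp.zeros((b, s_kv, h_kv, d_v))
    query = jnp.zeros((b, s_q, h_kv * g, d_qk))

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    # Flatten (h_kv, g) back into the query head axis, keeping v's head dim.
    context = jnp.reshape(context, query.shape[:-1] + (value.shape[-1],))
    assert context.shape == (b, s_q, h_kv * g, d_v)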
@@ -294,7 +295,8 @@ class FusedAttnRunner:
     max_seqlen_kv: int
     num_heads_q: int
     num_heads_kv: int
-    head_dim: int
+    head_dim_qk: int
+    head_dim_v: int
     attn_bias_type: AttnBiasType
     attn_mask_type: AttnMaskType
     dropout_prob: float
@@ -346,6 +348,14 @@ class FusedAttnRunner:
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )

+        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
+        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
+        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
+            pytest.skip(
+                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
+                "is either BSHD_BSHD_BSHD or THD_THD_THD"
+            )
+
         self.backend = FusedAttnHelper(
             self.is_training,
             self.dtype,
@@ -358,7 +368,8 @@ class FusedAttnRunner:
             self.num_heads_kv,
             self.max_seqlen_q,
             self.max_seqlen_kv,
-            self.head_dim,
+            self.head_dim_qk,
+            self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
         ).get_fused_attn_backend()
         if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
@@ -391,13 +402,9 @@ class FusedAttnRunner:
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

-        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
-        k_shape = v_shape = (
-            self.batch_size,
-            self.max_seqlen_kv,
-            self.num_heads_kv,
-            self.head_dim,
-        )
+        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
+        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
+        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
@@ -616,7 +623,7 @@ class FusedAttnRunner:
             raise ValueError(f"Unknown {self.seq_desc_format=}")

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
-        self.scaling_factor = 1.0 / sqrt(self.head_dim)
+        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

         # Setup distributed sharding specs
         # Setup shardings for distributed tests
@@ -935,9 +942,31 @@ class FusedAttnRunner:
     ],
 )
 @pytest.mark.parametrize(
-    "b, s_q, s_kv, h_q, h_kv, d, dtype",
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
     [
-        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
+        pytest.param(
+            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
+        ),
+        pytest.param(
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            id="2-2048-1024-12-12-64-64-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
+        ),
         pytest.param(
             2,
             2048,
@@ -945,11 +974,13 @@ class FusedAttnRunner:
             12,
             12,
             64,
+            32,
             jnp.bfloat16,
-            id="2-2048-1024-12-12-64-BF16-CROSS",
+            id="2-2048-1024-12-12-64-32-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
         ),
-        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
     ],
 )
 @pytest.mark.parametrize(
@@ -1003,7 +1034,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -1028,7 +1060,8 @@ class TestFusedAttn:
             s_kv,
             h_q,
             h_kv,
-            d,
+            d_qk,
+            d_v,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
@@ -1055,7 +1088,8 @@ class TestFusedAttn:
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -1077,7 +1111,8 @@ class TestFusedAttn:
             s_kv,
             h_q,
             h_kv,
-            d,
+            d_qk,
+            d_v,
             attn_bias_type,
             attn_mask_type,
             dropout_prob,
......
@@ -188,7 +188,7 @@ class ReorderStrategy(Enum):
     - DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between
       GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
-      mulitple of 2 * cp_size.
+      multiple of 2 * cp_size.
     Examples:
     - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
     - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
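
The mirror swap in that docstring can be stated compactly: split the global sequence into 2 * cp_size chunks; GPU i keeps its first sub-chunk and takes the second sub-chunk of its mirror GPU cp_size - 1 - i. A standalone sketch (hypothetical helper, not the TE implementation) that reproduces the example:

    import numpy as np

    def dual_chunk_swap(seq, cp_size):
        # GPU i ends up with sub-chunk 2*i plus sub-chunk 2*(cp_size - 1 - i) + 1.
        assert seq.shape[0] % (2 * cp_size) == 0  # max_seqlen must be a multiple of 2 * cp_size
        chunks = np.split(seq, 2 * cp_size)
        return [
            np.concatenate([chunks[2 * i], chunks[2 * (cp_size - 1 - i) + 1]])
            for i in range(cp_size)
        ]

    print(dual_chunk_swap(np.arange(16), cp_size=4))
    # [array([ 0,  1, 14, 15]), array([ 4,  5, 10, 11]),
    #  array([8, 9, 6, 7]), array([12, 13,  2,  3])]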
@@ -288,7 +288,8 @@ def is_fused_attn_kernel_available(
     kv_num_heads,
     q_max_seqlen,
     kv_max_seqlen,
-    head_dim,
+    head_dim_qk,
+    head_dim_v,
     window_size: Optional[Tuple[int, int]] = None,
 ):
     """
@@ -308,7 +309,8 @@ def is_fused_attn_kernel_available(
         kv_num_heads,
         q_max_seqlen,
         kv_max_seqlen,
-        head_dim,
+        head_dim_qk,
+        head_dim_v,
         (-1, -1) if window_size is None else window_size,
     )
@@ -491,7 +493,7 @@ def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
 @jax.tree_util.register_pytree_node_class
 class SequenceDescriptor:
-    """A class to descibe the sequences with flexible initialization.
+    """A class to describe the sequences with flexible initialization.
     - SequenceDescriptor.from_seqlens
       For non-THD (non-packed) cases, where each batch has only 1 sequence.
     - SequenceDescriptor.from_seqlens_and_offsets
......
@@ -114,7 +114,8 @@ class FusedAttnHelper:
     kv_num_heads: int
     q_max_seqlen: int
     kv_max_seqlen: int
-    head_dim: int
+    head_dim_qk: int
+    head_dim_v: int
     window_size: Tuple[int, int]

     def is_fused_attn_kernel_available(self):
@@ -135,7 +136,8 @@ class FusedAttnHelper:
             self.kv_num_heads,
             self.q_max_seqlen,
             self.kv_max_seqlen,
-            self.head_dim,
+            self.head_dim_qk,
+            self.head_dim_v,
             self.window_size[0],
             self.window_size[1],
         )
@@ -155,23 +157,49 @@ class FusedAttnHelper:
             kv_batch_shape = q_batch_shape
             kv_max_seqlen = q_max_seqlen
             num_gqa_groups = attn_heads
-            kv_head_dim = q_head_dim
+            v_head_dim = q_head_dim
             assert nqkv == 3
         elif qkv_layout.is_kvpacked():
             *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
-            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
+            *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
+            assert q_batch_shape == kv_batch_shape
+            assert q_head_dim == v_head_dim
             assert nkv == 2
         elif qkv_layout.is_separate():
             *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
-            *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
-            assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
+            *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
+            *v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
+            assert (
+                q_head_dim == k_head_dim
+            ), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
+            assert (
+                k_max_seqlen == v_max_seqlen
+            ), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
+            kv_max_seqlen = k_max_seqlen
+            assert q_batch_shape == k_batch_shape == v_batch_shape, (
+                f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
+                f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
+            )
+            assert k_num_gqa_groups == v_num_gqa_groups, (
+                f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
+                f" {v_num_gqa_groups}"
+            )
+            num_gqa_groups = k_num_gqa_groups
         else:
             raise ValueError(f"Unexpected {qkv_layout=}")

-        assert q_batch_shape == kv_batch_shape
-        assert q_head_dim == kv_head_dim
-        assert q_aval.dtype == k_aval.dtype == v_aval.dtype
+        assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
+            f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
+            f" {v_aval.dtype}"
+        )

-        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
+        return (
+            q_batch_shape,
+            q_max_seqlen,
+            kv_max_seqlen,
+            attn_heads,
+            num_gqa_groups,
+            q_head_dim,
+            v_head_dim,
+        )

 @dataclass(frozen=True)
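
The new assertions pin down the shape contract for the separate (BSHD_BSHD_BSHD / THD_THD_THD) layout. A minimal illustration with hypothetical shapes that satisfies it: q and k must share a head dim, while v may differ.

    import jax.numpy as jnp

    q = jnp.zeros((2, 2048, 12, 64))  # *batch, s_q,  attn_heads,     q_head_dim
    k = jnp.zeros((2, 1024, 6, 64))   # *batch, s_kv, num_gqa_groups, k_head_dim == q_head_dim
    v = jnp.zeros((2, 1024, 6, 32))   # *batch, s_kv, num_gqa_groups, v_head_dim may differ

    *q_batch, s_q, attn_heads, d_qk = q.shape
    *k_batch, s_kv, gqa_k, d_k = k.shape
    *v_batch, s_kv_v, gqa_v, d_v = v.shape
    assert d_qk == d_k and s_kv == s_kv_v and gqa_k == gqa_v
    assert q_batch == k_batch == v_batch
    # parse_qkv_aval now returns both head dims in a 7-tuple:
    parsed = (q_batch, s_q, s_kv, attn_heads, gqa_k, d_qk, d_v)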
@@ -269,11 +297,17 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
         )
-        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
-            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
-        )
+        (
+            batch_shape,
+            q_max_seqlen,
+            kv_max_seqlen,
+            attn_heads,
+            num_gqa_groups,
+            q_head_dim,
+            v_head_dim,
+        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

-        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
+        output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
         out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)

         # backend determines the softmax buffer shape/dtype
@@ -289,7 +323,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             num_gqa_groups,
             q_max_seqlen,
             kv_max_seqlen,
-            head_dim,
+            q_head_dim,
+            v_head_dim,
             config.window_size,
         ).get_fused_attn_backend()
@@ -340,7 +375,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             attn_heads,
             num_gqa_groups,
             bias_heads,
-            head_dim,
+            q_head_dim,
+            v_head_dim,
             config.scaling_factor,
             config.dropout_probability,
             config.attn_bias_type.value,
@@ -392,9 +428,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         """
         q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

-        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
-            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
-        )
+        (
+            batch_shape,
+            q_max_seqlen,
+            kv_max_seqlen,
+            attn_heads,
+            num_gqa_groups,
+            q_head_dim,
+            v_head_dim,
+        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

         input_batch = reduce(operator.mul, batch_shape)
@@ -433,7 +475,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             attn_heads=attn_heads,
             num_gqa_groups=num_gqa_groups,
             bias_heads=bias_heads,
-            head_dim=head_dim,
+            qk_head_dim=q_head_dim,
+            v_head_dim=v_head_dim,
             max_segments_per_seq=config.max_segments_per_seq,
             scaling_factor=float(config.scaling_factor),
             dropout_probability=float(config.dropout_probability),
@@ -711,9 +754,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
         assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
         assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

-        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
-            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
-        )
+        (
+            batch_shape,
+            q_max_seqlen,
+            kv_max_seqlen,
+            attn_heads,
+            num_gqa_groups,
+            qk_head_dim,
+            v_head_dim,
+        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

         if config.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_batch = bias_heads = 0
@@ -732,7 +781,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
             attn_heads,
             num_gqa_groups,
             bias_heads,
-            head_dim,
+            qk_head_dim,
+            v_head_dim,
             config.scaling_factor,
             config.dropout_probability,
             config.attn_bias_type.value,
@@ -791,9 +841,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
         """
         q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

-        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
-            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
-        )
+        (
+            batch_shape,
+            q_max_seqlen,
+            kv_max_seqlen,
+            attn_heads,
+            num_gqa_groups,
+            qk_head_dim,
+            v_head_dim,
+        ) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)

         input_batch = reduce(operator.mul, batch_shape)
@@ -835,7 +891,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
             attn_heads=attn_heads,
             num_gqa_groups=num_gqa_groups,
             bias_heads=bias_heads,
-            head_dim=head_dim,
+            qk_head_dim=qk_head_dim,
+            v_head_dim=v_head_dim,
             max_segments_per_seq=config.max_segments_per_seq,
             scaling_factor=float(config.scaling_factor),
             dropout_probability=float(config.dropout_probability),
......
@@ -101,20 +101,20 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
                                             NVTE_Mask_Type mask_type, float dropout_probability,
                                             size_t q_num_heads, size_t kv_num_heads,
                                             size_t q_max_seqlen, size_t kv_max_seqlen,
-                                            size_t head_dim, int64_t window_size_left,
-                                            int64_t window_size_right);
+                                            size_t qk_head_dim, size_t v_head_dim,
+                                            int64_t window_size_left, int64_t window_size_right);

 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);

 pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
     int64_t window_size_right);
......
@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
                                             NVTE_Mask_Type mask_type, float dropout_probability,
                                             size_t q_attn_heads, size_t kv_attn_heads,
                                             size_t q_max_seqlen, size_t kv_max_seqlen,
-                                            size_t head_dim, int64_t window_size_left,
-                                            int64_t window_size_right) {
+                                            size_t qk_head_dim, size_t v_head_dim,
+                                            int64_t window_size_left, int64_t window_size_right) {
   auto backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
-      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
+      kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
   return backend;
 }
@@ -117,24 +117,24 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
   // For qkv_packed
-  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
+  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
   auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

   // For kv_packed
-  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
-  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
+  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
   auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

   // For separate q, k, v
-  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
-  auto v_shape = k_shape;
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
   auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);

   auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
@@ -237,17 +237,17 @@ static void FusedAttnForwardImpl(
     void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
     void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
     size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
-    size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
-    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
-    bool deterministic, int64_t window_size_left, int64_t window_size_right) {
+    size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
+    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
+    bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
   FUSED_ATTN_IMPL_COMMON_BLOCK;

   /* Input tensors */
   auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

   if (is_ragged) {
-    auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
+    auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
     cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);

     // Memset to 0xF0 for filling large negative numbers
@@ -257,7 +257,7 @@ static void FusedAttnForwardImpl(
   /* Output tensors */
   auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
-  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
+  auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
   auto o_tensor = TensorWrapper(output, o_shape, dtype);

   /* Prepare RNG state */
@@ -265,7 +265,7 @@ static void FusedAttnForwardImpl(
   auto backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
       bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
-      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
+      kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
   nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -278,7 +278,7 @@ static void FusedAttnForwardImpl(
   /* Call the underlying NVTE API */
   auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
+    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
     auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
     nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
                                   o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
@@ -287,8 +287,9 @@ static void FusedAttnForwardImpl(
                                   qkv_layout, bias_type, mask_type, window_size_left,
                                   window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
+    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+    auto kv_shape =
+        std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
     nvte_fused_attn_fwd_kvpacked(
@@ -299,9 +300,9 @@ static void FusedAttnForwardImpl(
         is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto v_shape = k_shape;
+    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+    auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
@@ -327,7 +328,8 @@ static void FusedAttnForwardImpl(
   size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
   size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
   size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
-  size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim"); \
+  size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
+  size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
   size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
   auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
   auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
@@ -362,9 +364,9 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
       is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
       softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
       input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
-      head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
-      mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
-      window_size_right);
+      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
+      dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
+      deterministic, window_size_left, window_size_right);

   return ffi_with_cuda_error_check();
 }
@@ -391,33 +393,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
 pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
     bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
     int64_t window_size_right) {
   // For qkv_packed
-  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
+  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
   auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
   auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);

   // For kv_packed
-  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
   auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
   auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
-  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
+  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
   auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
   auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);

   // For separate q, k, v
-  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
   auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
   auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
-  auto v_shape = k_shape;
+  auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
   auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
   auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);

-  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
+  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
   auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
   auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
@@ -498,15 +500,15 @@ static void FusedAttnBackwardImpl(
     void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
     void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
-    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
+    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
+    size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
     float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
     bool deterministic, int64_t window_size_left, int64_t window_size_right) {
   FUSED_ATTN_IMPL_COMMON_BLOCK;

   /* Input tensors */
-  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
+  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
   auto output_tensor = TensorWrapper(output, output_shape, dtype);
   auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
@@ -520,14 +522,14 @@ static void FusedAttnBackwardImpl(
   auto backend = nvte_get_fused_attn_backend(
       is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
       bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
-      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
+      kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                      bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                      softmax_aux, rng_state, bias);

   /* Call the underly NVTE API */
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
+    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
     auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
     auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
     if (is_ragged) {
@@ -543,8 +545,9 @@ static void FusedAttnBackwardImpl(
                                   bias_type, mask_type, window_size_left, window_size_right,
                                   deterministic, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
+    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+    auto kv_shape =
+        std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
     auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
@@ -564,9 +567,9 @@ static void FusedAttnBackwardImpl(
         dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
         deterministic, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-    auto v_shape = k_shape;
+    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
+    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
+    auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto k_tensor = TensorWrapper(k, k_shape, dtype);
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
@@ -614,9 +617,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
       is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
       dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
       workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
-      attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
-      scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
-      is_training, deterministic, window_size_left, window_size_right);
+      attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
+      wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
+      wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);

   return ffi_with_cuda_error_check();
 }
......
@@ -594,6 +594,12 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             seqlen_kv = seqlen_q
         else:
             seqlen_kv = key.shape[sequence_dim]
+        if qkv_layout.is_separate():
+            head_dim_qk = query.shape[-1]
+            head_dim_v = value.shape[-1]
+        else:
+            head_dim_qk = self.head_dim
+            head_dim_v = self.head_dim

         has_fused_attn_kernel = is_fused_attn_kernel_available(
             # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
@@ -608,7 +614,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             self.num_gqa_groups,
             seqlen_q,
             seqlen_kv,
-            self.head_dim,
+            head_dim_qk,
+            head_dim_v,
             self.window_size,
         )
@@ -621,7 +628,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 "Please try to update the cuDNN and TE to the latest version.\n"
                 f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                 f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
-                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
+                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
             )

         dropout_rng = None
@@ -629,7 +636,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
             dropout_rng = self.make_rng(self.dropout_rng_name)

         if self.scale_factor is None:
-            scale_factor = 1.0 / sqrt(self.head_dim)
+            scale_factor = 1.0 / sqrt(head_dim_qk)
         else:
             scale_factor = self.scale_factor
         del self.scale_factor
......