Unverified Commit 1ddfa0c6 authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851)



* Add support for Fused Attn MLA head_dim_qk != head_dim_v
	Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
	Modify FusedAttnHelper to accept different head_dims for qk and v and modify assert dims checks in parse_qkv_aval()
	Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
	Modify Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
	Modify DotProductAttention call() to extract head dims separately for qk and v
	Modify the FusedAttn Tests to accommodate API changes in the FusedAttn API
	Add test case for head_dim_qk != head_dim_v (failing)
	Modify the baseline JAX appropriately to reshape the output vector based on v dims and not q dims
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix context dims in general DPA in test_fused_attn
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix dim for output tensor by replacing with v head dim rather than q head dim
Add test cases for jax fused attn where head_dim_qk != head_dim_v for a combination of data types and attention type
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Modify the fused attn jax unit test case for head dim qk != head dim v
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Use new FusedAttnRunner function signature for separate hidden dim for qk and v in Fused Attn distributed tests
Code clean up
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Fix usage of is_fused_attn signature in distributed tests
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove unnecessary assert
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 71c76b6b
......@@ -80,6 +80,7 @@ class TestDistributedSelfAttn:
seqlen,
seqlen,
hidden,
hidden,
None, # no window
):
pytest.skip("No FusedAttn backend found")
......@@ -99,6 +100,7 @@ class TestDistributedSelfAttn:
num_head,
num_head,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -227,6 +229,7 @@ class TestDistributedCrossAttn:
seqlen,
seqlen,
hidden,
hidden,
None, # no window
):
pytest.skip("No FusedAttn backend found")
......@@ -239,6 +242,7 @@ class TestDistributedCrossAttn:
num_head,
num_head,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -329,6 +333,7 @@ class TestDistributedContextParallelSelfAttn:
num_head,
num_kv_heads,
hidden,
hidden,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -360,6 +365,7 @@ class TestDistributedContextParallelSelfAttn:
seqlen,
seqlen,
hidden,
hidden,
None,
) # no SWA for CP
......
......@@ -106,7 +106,8 @@ def general_dot_product_attention(
softmax_out = softmax_out * multiplier
context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape)
context_shape = query.shape[:-1] + (value.shape[-1],)
context = jnp.reshape(context, context_shape)
return context
......@@ -294,7 +295,8 @@ class FusedAttnRunner:
max_seqlen_kv: int
num_heads_q: int
num_heads_kv: int
head_dim: int
head_dim_qk: int
head_dim_v: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
dropout_prob: float
......@@ -346,6 +348,14 @@ class FusedAttnRunner:
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
pytest.skip(
"For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)
self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
......@@ -358,7 +368,8 @@ class FusedAttnRunner:
self.num_heads_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
self.head_dim_qk,
self.head_dim_v,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
......@@ -391,13 +402,9 @@ class FusedAttnRunner:
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (
self.batch_size,
self.max_seqlen_kv,
self.num_heads_kv,
self.head_dim,
)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)
if self.attn_bias_type == AttnBiasType.NO_BIAS:
bias_shape = None
......@@ -616,7 +623,7 @@ class FusedAttnRunner:
raise ValueError(f"Unknown {self.seq_desc_format=}")
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
# Setup distributed sharding specs
# Setup shardings for distributed tests
......@@ -935,9 +942,31 @@ class FusedAttnRunner:
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d, dtype",
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
[
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(
2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
),
pytest.param(
2,
2048,
1024,
12,
12,
64,
64,
jnp.bfloat16,
id="2-2048-1024-12-12-64-64-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
),
pytest.param(
4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
),
pytest.param(
4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
),
pytest.param(
2,
2048,
......@@ -945,11 +974,13 @@ class FusedAttnRunner:
12,
12,
64,
32,
jnp.bfloat16,
id="2-2048-1024-12-12-64-BF16-CROSS",
id="2-2048-1024-12-12-64-32-BF16-CROSS",
),
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
],
)
@pytest.mark.parametrize(
......@@ -1003,7 +1034,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1028,7 +1060,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1055,7 +1088,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......@@ -1077,7 +1111,8 @@ class TestFusedAttn:
s_kv,
h_q,
h_kv,
d,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
dropout_prob,
......
......@@ -188,7 +188,7 @@ class ReorderStrategy(Enum):
- DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between
GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
mulitple of 2 * cp_size.
multiple of 2 * cp_size.
Examples:
- Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
- After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
......@@ -288,7 +288,8 @@ def is_fused_attn_kernel_available(
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
head_dim_qk,
head_dim_v,
window_size: Optional[Tuple[int, int]] = None,
):
"""
......@@ -308,7 +309,8 @@ def is_fused_attn_kernel_available(
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
head_dim_qk,
head_dim_v,
(-1, -1) if window_size is None else window_size,
)
......@@ -491,7 +493,7 @@ def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
@jax.tree_util.register_pytree_node_class
class SequenceDescriptor:
"""A class to descibe the sequences with flexible initialization.
"""A class to describe the sequences with flexible initialization.
- SequenceDescriptor.from_seqlens
For non-THD (non-packed) cases, where each batch has only 1 sequence.
- SequenceDescriptor.from_seqlens_and_offsets
......
......@@ -114,7 +114,8 @@ class FusedAttnHelper:
kv_num_heads: int
q_max_seqlen: int
kv_max_seqlen: int
head_dim: int
head_dim_qk: int
head_dim_v: int
window_size: Tuple[int, int]
def is_fused_attn_kernel_available(self):
......@@ -135,7 +136,8 @@ class FusedAttnHelper:
self.kv_num_heads,
self.q_max_seqlen,
self.kv_max_seqlen,
self.head_dim,
self.head_dim_qk,
self.head_dim_v,
self.window_size[0],
self.window_size[1],
)
......@@ -155,23 +157,49 @@ class FusedAttnHelper:
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
kv_head_dim = q_head_dim
v_head_dim = q_head_dim
assert nqkv == 3
elif qkv_layout.is_kvpacked():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == v_head_dim
assert nkv == 2
elif qkv_layout.is_separate():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}"
*k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
*v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
assert (
q_head_dim == k_head_dim
), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
assert (
k_max_seqlen == v_max_seqlen
), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
kv_max_seqlen = k_max_seqlen
assert q_batch_shape == k_batch_shape == v_batch_shape, (
f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
)
assert k_num_gqa_groups == v_num_gqa_groups, (
f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
f" {v_num_gqa_groups}"
)
num_gqa_groups = k_num_gqa_groups
else:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_batch_shape == kv_batch_shape
assert q_head_dim == kv_head_dim
assert q_aval.dtype == k_aval.dtype == v_aval.dtype
return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
f" {v_aval.dtype}"
)
return (
q_batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
)
@dataclass(frozen=True)
......@@ -269,11 +297,17 @@ class FusedAttnFwdPrimitive(BasePrimitive):
f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
)
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
......@@ -289,7 +323,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups,
q_max_seqlen,
kv_max_seqlen,
head_dim,
q_head_dim,
v_head_dim,
config.window_size,
).get_fused_attn_backend()
......@@ -340,7 +375,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
q_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
......@@ -392,9 +428,15 @@ class FusedAttnFwdPrimitive(BasePrimitive):
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
......@@ -433,7 +475,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
qk_head_dim=q_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
......@@ -711,9 +754,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
......@@ -732,7 +781,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
qk_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
......@@ -791,9 +841,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
......@@ -835,7 +891,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
qk_head_dim=qk_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
......
......@@ -101,20 +101,20 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right);
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
......
......@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
return backend;
}
......@@ -117,24 +117,24 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
......@@ -237,17 +237,17 @@ static void FusedAttnForwardImpl(
void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux,
void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads,
size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
size_t qk_head_dim, size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
if (is_ragged) {
auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim;
auto output_size = input_batch * q_max_seqlen * attn_heads * v_head_dim;
cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream);
// Memset to 0xF0 for filling large negative numbers
......@@ -257,7 +257,7 @@ static void FusedAttnForwardImpl(
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
/* Prepare RNG state */
......@@ -265,7 +265,7 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -278,7 +278,7 @@ static void FusedAttnForwardImpl(
/* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
......@@ -287,8 +287,9 @@ static void FusedAttnForwardImpl(
qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked(
......@@ -299,9 +300,9 @@ static void FusedAttnForwardImpl(
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
......@@ -327,7 +328,8 @@ static void FusedAttnForwardImpl(
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t head_dim = get_attr_value<int64_t>(attrs, "head_dim"); \
size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
......@@ -362,9 +364,9 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(),
softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(),
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, is_training,
deterministic, window_size_left, window_size_right);
return ffi_with_cuda_error_check();
}
......@@ -391,33 +393,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
......@@ -498,15 +500,15 @@ static void FusedAttnBackwardImpl(
void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets,
void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace,
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
size_t v_head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic, int64_t window_size_left, int64_t window_size_right) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, v_head_dim};
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
......@@ -520,14 +522,14 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
/* Call the underly NVTE API */
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype);
if (is_ragged) {
......@@ -543,8 +545,9 @@ static void FusedAttnBackwardImpl(
bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape =
std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
......@@ -564,9 +567,9 @@ static void FusedAttnBackwardImpl(
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
auto v_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
......@@ -614,9 +617,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(),
dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(),
workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen,
attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size,
scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, max_segments_per_seq,
wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype,
wkspace_dtype, is_training, deterministic, window_size_left, window_size_right);
return ffi_with_cuda_error_check();
}
......
......@@ -594,6 +594,12 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
seqlen_kv = seqlen_q
else:
seqlen_kv = key.shape[sequence_dim]
if qkv_layout.is_separate():
head_dim_qk = query.shape[-1]
head_dim_v = value.shape[-1]
else:
head_dim_qk = self.head_dim
head_dim_v = self.head_dim
has_fused_attn_kernel = is_fused_attn_kernel_available(
# This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
......@@ -608,7 +614,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
self.num_gqa_groups,
seqlen_q,
seqlen_kv,
self.head_dim,
head_dim_qk,
head_dim_v,
self.window_size,
)
......@@ -621,7 +628,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
"Please try to update the cuDNN and TE to the latest version.\n"
f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{self.head_dim=}\n"
f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
)
dropout_rng = None
......@@ -629,7 +636,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
dropout_rng = self.make_rng(self.dropout_rng_name)
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(self.head_dim)
scale_factor = 1.0 / sqrt(head_dim_qk)
else:
scale_factor = self.scale_factor
del self.scale_factor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment