Unverified Commit 687697a7 authored by Reese Wang, committed by GitHub

[JAX] Add experimental internally-used THD (packed) fused attn API (#964)



* Integrate experimental ragged offset
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use per sequence based offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove v/o_seq_offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16 sanity tests and remove forward tests from the automatically run tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance input checks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Separate fused attn into 2 different APIs and add the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add experimental to the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime segments check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove finished TODO
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 7669bf3d
......@@ -15,8 +15,7 @@ from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
fused_attn_qkvpacked,
fused_attn_kvpacked,
fused_attn,
AttnBiasType,
AttnMaskType,
QKVLayout,
......@@ -120,11 +119,15 @@ class TestDistributedSelfAttn:
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn_qkvpacked(
qkv,
fused_attn(
(qkv,),
bias,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......@@ -252,12 +255,15 @@ class TestDistributedCrossAttn:
def target_func(q, kv, mask):
return jnp.mean(
fused_attn_kvpacked(
q,
kv,
fused_attn(
(q, kv),
None,
mask,
None,
None,
None,
None,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
......
......@@ -9,6 +9,7 @@ from math import sqrt
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.linen import combine_masks
......@@ -22,12 +23,12 @@ from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
fused_attn_qkvpacked,
fused_attn_kvpacked,
QKVFormat,
fused_attn,
fused_attn_thd,
get_qkv_format,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
......@@ -102,7 +103,7 @@ def is_causal_mask(mask: AttnMaskType):
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
"""
Create inverse padded causal mask where `True` means allowing the corresponding
position to participate in attention and `False` means masking out that position.
......@@ -110,31 +111,75 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask)
return inv_causal_mask
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
def make_mask(
q_token: ArrayLike,
kv_token: ArrayLike,
segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike,
attn_mask_type: AttnMaskType,
) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
masking out the corresponding position and a `False` value means allowing
that position to participate in attention.
"""
inv_mask = make_attention_mask(
q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
)
if is_causal_mask(attn_mask_type):
inv_mask = make_decoder_mask(q_token, kv_token)
else:
inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
inv_causal_mask = make_causal_mask(q_token, kv_token)
inv_mask = combine_masks(inv_causal_mask, inv_mask)
if segment_pad_q is not None and segment_pad_kv is not None:
inv_pad_mask = make_attention_mask(
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
)
inv_mask = combine_masks(inv_pad_mask, inv_mask)
mask = jnp.logical_not(inv_mask)
return mask
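
For concreteness, a tiny hand-computed illustration of these mask semantics (my own sketch, not part of the diff; it reuses the helper above):

```python
# Two tokens in segment 1, one token in segment 2, one padded token (0).
# With PADDING_MASK, attention is only allowed within the same segment.
q_tok = kv_tok = jnp.asarray([[1, 1, 2, 0]])
mask = make_mask(q_tok, kv_tok, None, None, AttnMaskType.PADDING_MASK)
# mask[0, 0] (True means masked out):
# [[False, False,  True,  True],
#  [False, False,  True,  True],
#  [ True,  True, False,  True],
#  [ True,  True,  True,  True]]
```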
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
def get_seqlens_and_offsets(segment_ids, segment_pad):
batch, max_seqlen = segment_ids.shape
bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
seqlens = seqlens_with_zero[..., 1:]
def _find_offsets(x):
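# Find segment start positions: the id differs from the previous position and is nonzero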
same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
first_column = jnp.ones((x.shape[0], 1), dtype=bool)
same_as_previous = jnp.hstack((first_column, same_as_previous))
return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
same_as_previous
).squeeze(-1)
offsets = _find_offsets(segment_ids)
offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
if segment_pad is not None:
segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
else:
output = jnp.insert(seqlens, -1, values=0, axis=-1)
return output, offsets
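
A small hand-computed example of this helper (illustrative values only): segment 1 occupies positions 0-1 with position 1 being intra-segment padding, segment 2 occupies positions 2-3, and position 4 is trailing padding.

```python
segment_ids = jnp.asarray([[1, 1, 2, 2, 0]])
segment_pad = jnp.asarray([[0, 1, 0, 0, 1]])
seqlens, offsets = get_seqlens_and_offsets(segment_ids, segment_pad)
# seqlens == [[1, 2, 0, 0, 0]]         padding-aware lengths of segments 1 and 2
# offsets == [[0, 2, -1, -1, -1, -1]]  segment start offsets; -1 marks unused slots
```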
@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
"""Use JIT to speed up the verifications"""
primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
output = general_dot_product_attention(
query,
key,
......@@ -150,29 +195,43 @@ def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
return output.astype(query.dtype)
def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
def customcall_fused_dpa(
query,
key,
value,
bias,
mask,
seqlens_q,
seqlens_kv,
offsets_q,
offsets_kv,
dropout_rng,
**kwargs,
):
"""
TE customcall dot product attention implementation
"""
attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
qkv_layout = kwargs.pop("qkv_layout")
qkv_layout = kwargs["qkv_layout"]
is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
match qkv_layout:
case QKVLayout.BS3HD:
case QKVLayout.BS3HD | QKVLayout.T3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
qkv_args = (qkv,)
case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
query.dtype
)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
query.dtype
)
qkv_args = (query, kv)
case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
if not is_thd:
kwargs.pop("max_segments_per_seq")
return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
return fused_attn_thd(
qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
).astype(query.dtype)
class BiasShape(Enum):
......@@ -207,11 +266,18 @@ class FusedAttnRunner:
bias_shape: BiasShape
def _check_configs(self):
if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
# TODO(rewang): probably add this to is_fused_attn_available
if get_qkv_format(self.qkv_layout) == QKVFormat.THD and self.attn_mask_type not in [
AttnMaskType.PADDING_MASK,
AttnMaskType.PADDING_CAUSAL_MASK,
]:
pytest.skip("THD format requires padding masks.")
if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD:
if self.num_heads_q != self.num_heads_kv:
pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(
self.dtype,
......@@ -293,10 +359,78 @@ class FusedAttnRunner:
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return valid_len, tokens
return tokens, jnp.logical_not(tokens)
def generate_random_segment_ids(
batch_size, sequence_length, num_segments, seed, with_segment_pad=True
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
# [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
# Max segment size does not include paddings
max_segment_size = sequence_length // num_segments
for i in range(batch_size):
current_pos = 0
segment_id = 1
for _ in range(num_segments):
segment_size = rng.integers(1, max_segment_size + 1)
if current_pos + segment_size > sequence_length:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
if with_segment_pad:
num_valid = rng.integers(1, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
segment_pad[i, current_pos:sequence_length] = 1
return segment_ids, segment_pad
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
self.num_segments_per_seq = 3
self.token_q, self.segment_pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
# TODO(rewang): Check if qkvpacked supports different q/kv
# TODO(rewang): Causal with different q/kv segment_id fails
if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
self.token_kv = self.token_q
self.segment_pad_kv = self.segment_pad_q
else:
self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
)
self.pad_q = self.segment_pad_q
self.pad_kv = self.segment_pad_kv
else:
self.num_segments_per_seq = 1
self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.segment_pad_q = self.segment_pad_kv = None
self.mask = make_mask(
self.token_q,
self.token_kv,
self.segment_pad_q,
self.segment_pad_kv,
self.attn_mask_type,
)
self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
self.token_q, self.segment_pad_q
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
self.token_kv, self.segment_pad_kv
)
self.mask_for_customcall = None # THD format doesn't support mask
else:
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
self.mask_for_customcall = self.mask
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
......@@ -307,7 +441,19 @@ class FusedAttnRunner:
"""
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
......@@ -315,17 +461,19 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq,
}
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
reference_out = jax_dpa(*args, **kwargs)
if self.is_training and self.dropout_prob > 0.0:
return
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
_split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
)
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
......@@ -341,14 +489,28 @@ class FusedAttnRunner:
def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = self.valid_len_q * self.num_heads_q
gradient_multiplier = self.max_seqlen_q * self.num_heads_q
if is_causal_mask(self.attn_mask_type):
gradient_multiplier /= 10
# Keep only valid result for the gradient
ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
self.q,
self.k,
self.v,
self.bias,
self.mask_for_customcall,
self.seqlens_q,
self.seqlens_kv,
self.offsets_q,
self.offsets_kv,
self.dropout_rng,
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
......@@ -356,6 +518,7 @@ class FusedAttnRunner:
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self.num_segments_per_seq,
}
# We can compute dBias only for the [1, h, s, s] layout
......@@ -377,54 +540,56 @@ class FusedAttnRunner:
)
)
primitive_out, primitive_dgrad = jitted_primitive(*args)
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
reference_out, reference_dgrad = jitted_reference(*args)
# Skip elementwise comparison when dropout enabled
if self.dropout_prob > 0.0:
return
assert_allclose(
primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)
def check_dqkv(primitive, reference, pad):
primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
_split_valid_and_invalid(primitive, reference, pad)
)
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
# Convert the outputs to float32 for the elementwise comparison
primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])
primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)
check_dqkv(primitive_dq, reference_dq, self.pad_q)
check_dqkv(primitive_dk, reference_dk, self.pad_kv)
check_dqkv(primitive_dv, reference_dv, self.pad_kv)
if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
primitive_dbias = primitive_dgrad[3]
reference_dbias = reference_dgrad[3]
# Assume all batches have the same actual_seqlen; the tests probably need to be extended
bias_mask = self.mask[0, 0]
# Assert all masked dbias are 0s
assert_allclose(
primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
jnp.where(bias_mask, primitive_dbias, 0),
jnp.zeros_like(primitive_dbias),
dtype=self.dtype,
)
# dbias padded part
assert_allclose(
primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
jnp.where(bias_mask, primitive_dbias, 0),
jnp.where(bias_mask, reference_dbias, 0),
dtype=self.dtype,
)
# dbias valid part
assert_allclose(
primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
jnp.where(bias_mask, 0, primitive_dbias),
jnp.where(bias_mask, 0, reference_dbias),
dtype=self.dtype,
)
......@@ -454,24 +619,21 @@ class FusedAttnRunner:
pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d",
"b, s_q, s_kv, h_q, h_kv, d, dtype",
[
pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"),
pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
],
)
@pytest.mark.parametrize(
......@@ -494,7 +656,7 @@ class TestFusedAttn:
pytest.param(False, id="INFERENCE"),
],
)
def test_forward(
def _test_forward(
b,
s_q,
s_kv,
......@@ -511,6 +673,8 @@ class TestFusedAttn:
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming.
It is kept for development and debugging.
"""
runner = FusedAttnRunner(
b,
......
......@@ -5,6 +5,7 @@
from enum import Enum
from functools import partial
from typing import Optional, Tuple
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
......@@ -12,6 +13,8 @@ import jax.numpy as jnp
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
from transformer_engine.transformer_engine_jax import nvte_get_qkv_format
from . import cpp_extensions as tex
......@@ -43,11 +46,42 @@ class AttnMaskType(Enum):
class QKVLayout(Enum):
"""QKV layout"""
"""
BSHD Format:
- BS3HD: q, k, v are interleaved and packed into a tensor with shape [b, s, 3, h, d].
- BSHD_BS2HD: q has shape [b, s, h, d]; k, v are interleaved and packed with shape [b, s, 2, h, d].
- BSHD_BSHD_BSHD: q, k, v are separate tensors, each with shape [b, s, h, d].
THD Format: Shapes are the same as in the BSHD layout, but multiple segments may be packed into a sequence.
- T3HD: q, k, v are interleaved and packed into a tensor with shape [b, s, 3, h, d].
- THD_T2HD: q has shape [b, s, h, d]; k, v are interleaved and packed with shape [b, s, 2, h, d].
- THD_THD_THD: q, k, v are separate tensors, each with shape [b, s, h, d].
"""
BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
T3HD = NVTE_QKV_Layout.NVTE_T3HD
THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD
class QKVFormat(Enum):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
BSHD: q,k,v memory layout with [b, s, ..., h, d]
THD: q,k,v memory layout is the same as BSHD, but multiple segments may be packed in a sequence.
"""
SBHD = NVTE_QKV_Format.NVTE_SBHD
BSHD = NVTE_QKV_Format.NVTE_BSHD
THD = NVTE_QKV_Format.NVTE_THD
def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
"""
return QKVFormat(nvte_get_qkv_format(qkv_layout.value))
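
For instance (a small sketch, assuming only the enums above):

```python
# Every layout maps to the format family it belongs to.
assert get_qkv_format(QKVLayout.BS3HD) == QKVFormat.BSHD
assert get_qkv_format(QKVLayout.BSHD_BSHD_BSHD) == QKVFormat.BSHD
assert get_qkv_format(QKVLayout.THD_THD_THD) == QKVFormat.THD
```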
def canonicalize_attn_mask_type(attn_mask_type: str):
......@@ -102,414 +136,357 @@ def is_fused_attn_kernel_available(
).is_fused_attn_kernel_available()
def fused_attn_qkvpacked(
qkv: jnp.ndarray,
bias: jnp.ndarray | None,
mask: jnp.ndarray,
seed: jnp.ndarray | None,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Fused attention with the qkvpacked inputs
"""
output = _fused_attn_qkvpacked(
qkv,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return output
def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
match qkv_layout:
case QKVLayout.BS3HD | QKVLayout.T3HD:
assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = q_max_seqlen
case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = qkv[1].shape[1]
case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
batch, q_max_seqlen, *_ = qkv[0].shape
kv_max_seqlen = qkv[1].shape[1]
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return batch, q_max_seqlen, kv_max_seqlen
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
def _fused_attn_qkvpacked(
qkv: jnp.ndarray,
bias: jnp.ndarray | None,
mask: jnp.ndarray,
seed: jnp.ndarray | None,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
output, _ = _fused_attn_fwd_qkvpacked_rule(
qkv,
bias,
mask,
seed,
attn_bias_type,
attn_mask_type,
scaling_factor,
dropout_probability,
is_training,
)
return output
def _fused_attn_fwd_qkvpacked_rule(
qkv: jnp.ndarray,
bias: jnp.ndarray | None,
mask: jnp.ndarray,
seed: jnp.ndarray | None,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, seqlen, *_ = qkv.shape
actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked(
qkv,
bias,
actual_seqlen,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
rng_state = checkpoint_name(rng_state, "context")
return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
def _fused_attn_bwd_qkvpacked_rule(
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
):
qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(
qkv,
bias,
softmax_aux,
rng_state,
output,
dz,
actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_qkv, grad_bias, None, None
_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
def fused_attn_kvpacked(
q: jnp.ndarray,
kv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
def fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
mask: Optional[jnp.ndarray],
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Fused attention with the kvpacked inputs
Perform non-THD (non-packed) cuDNN fused attention.
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
mask (Optional[jnp.ndarray]):
An optional mask tensor to mask out the attention scores; `True` means mask out.
Intra-sequence padding is not supported: padded tokens may only appear at the
right-most end of each sequence, otherwise the results will be wrong.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
output = _fused_attn_kvpacked(
q,
kv,
bias,
mask,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def _fused_attn_kvpacked(
q: jnp.ndarray,
kv: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
output, _ = _fused_attn_fwd_kvpacked_rule(
q,
kv,
bias,
mask,
seed,
attn_bias_type,
attn_mask_type,
scaling_factor,
dropout_probability,
is_training,
)
return output
def _fused_attn_fwd_kvpacked_rule(
q,
kv,
bias,
mask,
seed,
attn_bias_type,
attn_mask_type,
scaling_factor,
dropout_probability,
is_training,
):
assert (
get_qkv_format(qkv_layout) != QKVFormat.THD
), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
# Convert the mask to seqlens; the mask path doesn't support ragged offsets
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = kv.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
else:
# When the mask is causal, the actual seqlen is not in the last row; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = tex.fused_attn_fwd_kvpacked(
q,
kv,
output = _fused_attn(
qkv,
bias,
q_actual_seqlen,
kv_actual_seqlen,
q_seq_lens,
kv_seq_lens,
None,
None,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
rng_state = checkpoint_name(rng_state, "context")
return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
def _fused_attn_bwd_kvpacked_rule(
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
):
q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(
q,
kv,
bias,
softmax_aux,
rng_state,
output,
dz,
q_actual_seqlen,
kv_actual_seqlen,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=1,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_kv, grad_bias, None, None
_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
return output
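
A minimal usage sketch of this API (shapes and hyperparameters are illustrative, not taken from the diff):

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn, AttnBiasType, AttnMaskType, QKVLayout,
)

b, s, h, d = 2, 128, 12, 64
qkv_packed = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
out = fused_attn(
    (qkv_packed,),  # BS3HD: a single interleaved QKV tensor
    None,           # no bias
    None,           # mask is unused for NO_MASK/CAUSAL_MASK
    None,           # no dropout seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BS3HD,
    scaling_factor=1.0 / (d**0.5),
    dropout_probability=0.0,
    is_training=True,
)
```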
def fused_attn(
q: jnp.ndarray,
k: jnp.ndarray,
v: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
seed: jnp.ndarray,
def fused_attn_thd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
q_seq_lens: jnp.ndarray,
kv_seq_lens: jnp.ndarray,
q_seq_offsets: jnp.ndarray,
kv_seq_offsets: jnp.ndarray,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int = 1,
):
"""
Dot product attention with the separated query, key, value
(Experimental) Perform THD (packed) cuDNN fused attention.
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
q_seq_lens (jnp.ndarray):
Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
padded with -1.
kv_seq_lens (jnp.ndarray):
Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused
positions are padded with -1.
q_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
Unused positions are padded with -1.
kv_seq_offsets (jnp.ndarray):
The offsets in the sequence dim for the key and value, with shape
[batch, max_seqlen + 1]. Unused positions are padded with -1.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
The maximum number of segments inside a sequence. This parameter bounds the resource
usage and must be static during end-to-end training. The XLA compile time and memory
consumption are proportional to `max_segments_per_seq`.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
Examples:
>>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
>>> b, s, h, d = 2, 4, 12, 64
>>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
>>> # 3 segments in first seq, 2 segments in second seq
>>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
>>> # seq_offsets need to include the end offset of the last segment
>>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
>>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
q_seq_offsets, kv_seq_offsets, None,
AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
QKVLayout.T3HD, 0.125, 0, True, 3)
"""
assert (
get_qkv_format(qkv_layout) == QKVFormat.THD
), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."
# Check inputs qkv
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
assert q_seq_lens.shape == (batch, q_max_seqlen)
assert kv_seq_lens.shape == (batch, kv_max_seqlen)
assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)
output = _fused_attn(
q,
k,
v,
qkv,
bias,
mask,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
def _fused_attn(
q: jnp.ndarray,
k: jnp.ndarray,
v: jnp.ndarray,
bias: jnp.ndarray,
mask: jnp.ndarray,
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
q_seq_lens: jnp.ndarray,
kv_seq_lens: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
seed: jnp.ndarray,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
):
output, _ = _fused_attn_fwd_rule(
q,
k,
v,
qkv,
bias,
mask,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
)
return output
def _fused_attn_fwd_rule(
q,
k,
v,
qkv,
bias,
mask,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
seed,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
batch, s_q, *_ = q.shape
s_kv = k.shape[1]
q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
else:
assert mask is not None
mask = jnp.logical_not(mask)
q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
if attn_mask_type == AttnMaskType.PADDING_MASK:
kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,)
else:
# When the mask is causal, the actual seqlen is not in the last row; use max to find it
kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
output, softmax_aux, rng_state = tex.fused_attn_fwd(
q,
k,
v,
qkv,
bias,
q_actual_seqlen,
kv_actual_seqlen,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
seed,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
output = checkpoint_name(output, "context")
softmax_aux = checkpoint_name(softmax_aux, "context")
rng_state = checkpoint_name(rng_state, "context")
return output, (
q,
k,
v,
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
softmax_aux,
rng_state,
output,
q_actual_seqlen,
kv_actual_seqlen,
)
def _fused_attn_bwd_rule(
attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
ctx,
dz,
):
q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(
q,
k,
v,
(
qkv,
bias,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
softmax_aux,
rng_state,
output,
) = ctx
grad_qkv, grad_bias = tex.fused_attn_bwd(
qkv,
bias,
softmax_aux,
rng_state,
output,
dz,
q_actual_seqlen,
kv_actual_seqlen,
q_seq_lens,
kv_seq_lens,
q_seq_offsets,
kv_seq_offsets,
attn_bias_type=attn_bias_type.value,
attn_mask_type=attn_mask_type.value,
qkv_layout=qkv_layout.value,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
if attn_bias_type == AttnBiasType.NO_BIAS:
grad_bias = None
return grad_q, grad_k, grad_v, grad_bias, None, None
return grad_qkv, grad_bias, None, None, None, None, None
_fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
......@@ -5,6 +5,7 @@
from dataclasses import dataclass
from functools import partial, reduce
import operator
from typing import Optional, Tuple
import warnings
import jax.numpy as jnp
......@@ -18,7 +19,9 @@ from transformer_engine.transformer_engine_jax import (
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_QKV_Layout,
NVTE_QKV_Format,
NVTE_Fused_Attn_Backend,
nvte_get_qkv_format,
)
from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
......@@ -37,10 +40,6 @@ from ..sharding import (
__all__ = [
"FusedAttnHelper",
"fused_attn_fwd_qkvpacked",
"fused_attn_bwd_qkvpacked",
"fused_attn_fwd_kvpacked",
"fused_attn_bwd_kvpacked",
"fused_attn_fwd",
"fused_attn_bwd",
]
......@@ -88,18 +87,18 @@ class FusedAttnHelper:
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
kv_head_dim = q_head_dim
assert nqkv == 3
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
assert nkv == 2
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
assert k_aval.shape == v_aval.shape
......@@ -158,8 +157,9 @@ def generate_cu_seqlen(actual_seqlen):
"""
Generate the cumulative sequence lengths (cu_seqlen) for a batch
"""
cu_seqlen = jnp.cumsum(actual_seqlen)
cu_seqlen = jnp.hstack((0, cu_seqlen))
cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1)
cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen)
cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1)
return cu_seqlen
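
For instance, with -1 marking unused THD slots (hand-computed):

```python
actual_seqlen = jnp.asarray([3, 5, -1])
cu_seqlen = generate_cu_seqlen(actual_seqlen)
# cumsum          -> [3, 8, 7]
# mask negatives  -> [3, 8, -1]
# prepend zero    -> [0, 3, 8, -1]
```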
......@@ -170,7 +170,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12)
impl_static_args = (9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
......@@ -182,6 +182,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
seed_aval,
*,
attn_bias_type,
......@@ -190,6 +192,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
"""
Fused attention fwd abstract
......@@ -227,7 +230,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
......@@ -266,6 +269,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
max_segments_per_seq,
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -292,6 +296,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
*,
attn_bias_type,
......@@ -300,11 +306,22 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
......@@ -337,6 +354,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
max_segments_per_seq,
wkspace_aval.size,
scaling_factor,
dropout_probability,
......@@ -360,6 +378,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
attn_bias_type,
attn_mask_type,
......@@ -367,11 +387,64 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=-1)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather the valid q_seqlen entries, which are greater than 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 10, 12, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather the valid q_seq_offsets entries, which are greater than or equal to 0
# [[0, 3, 5, -1], [8, 10, 12, -1]] -> [[0, 3, 5, 8], [10, 12, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused positions to the max size (batch * max_seqlen)
# [[0, 3, 5, 8], [10, 12, -1, -1]] -> [[0, 3, 5, 8], [10, 12, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
......@@ -380,6 +453,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
......@@ -387,6 +462,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
return output, softmax_aux, rng_state
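
To make the THD preprocessing above concrete, here is a hand-checked standalone sketch of the `convert_to_2d` and `_fix_len_take` steps (illustrative values only, assuming batch = 2 and max_seqlen = 8):

```python
import jax.numpy as jnp

batch, max_seqlen = 2, 8
offsets = jnp.asarray([[0, 3, 5, -1], [0, 2, 4, -1]])

# convert_to_2d: add row_index * max_seqlen to every valid entry so the
# per-sequence offsets become offsets into the flattened [b*s] dimension.
offsets_2d = jnp.where(
    offsets >= 0,
    offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
    offsets,
)
# offsets_2d == [[0, 3, 5, -1], [8, 10, 12, -1]]

# _fix_len_take: compact the valid entries (>= 0) to the front, fill the
# tail with -1, then map the remaining -1s to batch * max_seqlen.
flat = offsets_2d.flatten()
indices = jnp.nonzero(flat >= 0, size=flat.size, fill_value=flat.size)[0]
compacted = jnp.take(flat, indices, fill_value=-1).reshape(offsets_2d.shape)
# compacted == [[0, 3, 5, 8], [10, 12, -1, -1]]
final = jnp.where(compacted < 0, batch * max_seqlen, compacted)
# final == [[0, 3, 5, 8], [10, 12, 16, 16]]
```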
......@@ -401,6 +477,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
......@@ -416,6 +493,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
),
out_bdims,
)
......@@ -428,29 +506,30 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos
del dropout_probability, is_training, max_segments_per_seq, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
# q_spec = (...batch, q_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
)
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])
)
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
......@@ -470,6 +549,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
......@@ -489,6 +569,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
return mesh, impl, out_shardings, arg_shardings
......@@ -503,7 +584,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward"
multiple_results = True
impl_static_args = (10, 11, 12, 13, 14, 15)
impl_static_args = (12, 13, 14, 15, 16, 17, 18)
inner_primitive = None
outer_primitive = None
......@@ -517,8 +598,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
rng_state_aval,
output_aval,
doutput_aval,
q_cu_seqlen_aval,
kv_cu_seqlen_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
*,
attn_bias_type,
attn_mask_type,
......@@ -526,6 +609,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
"""
Fused attention bwd abstract
......@@ -538,7 +622,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
......@@ -567,6 +651,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
max_segments_per_seq,
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
......@@ -600,6 +685,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
*,
attn_bias_type,
attn_mask_type,
......@@ -607,6 +694,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
"""
Fused attention bwd lowering rules
......@@ -622,6 +710,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
......@@ -656,6 +746,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
num_gqa_groups,
bias_heads,
head_dim,
max_segments_per_seq,
wkspace_aval.size,
scaling_factor,
dropout_probability,
......@@ -683,17 +774,73 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type,
attn_mask_type,
qkv_layout,
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
assert FusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen)
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
def _fix_len_take(x, condition):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
# TODO(rewang): try indices_are_sorted
y = jnp.take(x, indices, fill_value=-1)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather the valid q_seqlen entries, which are greater than 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 10, 12, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather the valid q_seq_offsets entries, which are greater than or equal to 0
# [[0, 3, 5, -1], [8, 10, 12, -1]] -> [[0, 3, 5, 8], [10, 12, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused positions to the max size (batch * max_seqlen)
# [[0, 3, 5, 8], [10, 12, -1, -1]] -> [[0, 3, 5, 8], [10, 12, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
......@@ -706,12 +853,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
return dq, dk, dv, dbias
......@@ -726,6 +876,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
):
check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
......@@ -741,6 +892,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
),
out_bdims,
)
......@@ -753,11 +905,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
):
del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor
del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq
del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
......@@ -777,6 +930,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor,
dropout_probability,
is_training,
max_segments_per_seq,
mesh,
arg_infos,
result_infos,
......@@ -794,7 +948,18 @@ class FusedAttnBwdPrimitive(BasePrimitive):
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
......@@ -807,12 +972,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
......@@ -825,245 +993,182 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive)
def fused_attn_fwd_qkvpacked(
qkv: jnp.ndarray,
bias: jnp.ndarray,
seqlen: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
_not_used = jnp.zeros(0, qkv.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(
qkv,
_not_used,
_not_used,
bias,
seqlen,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
def fused_attn_bwd_qkvpacked(
qkv: jnp.ndarray,
bias: jnp.ndarray,
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE self fused attention bwd
Return the gradients of self fused attention with packed qkv input
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
dummy_input = jnp.zeros(0, dtype=qkv.dtype)
dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
qkv,
dummy_input,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
seqlen,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return dqkv, dbias
def fused_attn_fwd_kvpacked(
q: jnp.ndarray,
kv: jnp.ndarray,
bias: jnp.ndarray,
def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
seed: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
seed: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
max_segments_per_seq: int,
) -> jnp.ndarray:
"""
Wrapper for TE fused attention fwd with kvpacked inputs
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Perform the forward pass with the cuDNN fused attention implementation.
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the key and value, with shape [batch + 1,].
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int): Maximum number of segments per sequence, used to bound
the cuDNN graphs created for THD (ragged) inputs.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
assert (q_seq_offsets is None) == (
kv_seq_offsets is None
), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
# Placeholders for optional tensors; custom calls do not support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
bias = jnp.zeros(0, dtype=qkv[0].dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(
q,
kv,
jnp.zeros(0, q.dtype),
*qkv_for_primitive,
bias,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
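For reference, the length of the `qkv` tuple selects the layout family checked above. An illustrative sketch of the three accepted input structures, using placeholder BSHD-style shapes and dtype (THD layouts instead pack tokens along the sequence dim):

```python
import jax.numpy as jnp

b, s, h, d = 2, 128, 16, 64
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)
qkv_packed = jnp.zeros((b, s, 3, h, d), jnp.bfloat16)  # e.g. NVTE_BS3HD
kv_packed = jnp.zeros((b, s, 2, h, d), jnp.bfloat16)   # e.g. NVTE_BSHD_BS2HD

qkv_variants = {
    "qkv-packed (self-attention)": (qkv_packed,),
    "kv-packed (cross-attention)": (q, kv_packed),
    "separate q/k/v": (q, k, v),
}
```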
def fused_attn_bwd_kvpacked(
q: jnp.ndarray,
kv: jnp.ndarray,
bias: jnp.ndarray,
def fused_attn_bwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
):
"""
Wrapper for TE fused attention bwd with kvpacked inputs
Return the gradients of fused attention with packed kv input
Perform the backward pass of the cuDNN fused attention implementation.
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
used in the forward pass. It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensor storing the random state used in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
doutput (jnp.ndarray): The gradient with respect to the output.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the key and value, with shape [batch + 1,].
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int): Maximum number of segments per sequence, used to bound
the cuDNN graphs created for THD (ragged) inputs.
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
same format as the input `qkv`.
- The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
dummy_input = jnp.zeros(0, q.dtype)
dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
q,
kv,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return dq, dkv, dbias
def fused_attn_fwd(
q: jnp.ndarray,
k: jnp.ndarray,
v: jnp.ndarray,
bias: jnp.ndarray,
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE fused attention fwd, where query, key, value are separate tensors
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
assert (q_seq_offsets is None) == (
kv_seq_offsets is None
), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
# Placeholders for optional tensors; custom calls do not support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
bias = jnp.zeros(0, dtype=qkv[0].dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(
q,
k,
v,
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
def fused_attn_bwd(
q: jnp.ndarray,
k: jnp.ndarray,
v: jnp.ndarray,
bias: jnp.ndarray,
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE fused attention bwd
Return the gradients of fused attention with separate query, key, value tensors
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(
q,
k,
v,
*qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind(
*qkv_for_primitive,
bias,
softmax_aux,
rng_state,
......@@ -1071,10 +1176,14 @@ def fused_attn_bwd(
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
)
return tuple(qkv_grads[: len(qkv)]), bias_grad
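The gradients come back in the same tuple structure as the input `qkv`, plus `dbias`. A small sketch of the unpacking contract (names are illustrative):

```python
def unpack_grads(qkv, qkv_grads, bias_grad):
    # fused_attn_bwd returns one gradient per tensor in `qkv`:
    #   (qkv_packed,)       -> (dqkv,)
    #   (query, kv_packed)  -> (dq, dkv)
    #   (query, key, value) -> (dq, dk, dv)
    assert len(qkv_grads) == len(qkv)
    return (*qkv_grads, bias_grad)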
......@@ -137,6 +137,7 @@ struct CustomCallFusedAttnDescriptor {
size_t num_gqa_groups;
size_t bias_heads;
size_t head_dim;
size_t max_segments_per_seq;
size_t wkspace_size;
float scaling_factor;
float dropout_probability;
......@@ -151,9 +152,9 @@ struct CustomCallFusedAttnDescriptor {
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training);
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);
// Transpose
......@@ -249,7 +250,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......@@ -257,7 +259,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -104,7 +104,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
......@@ -128,128 +129,50 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
bool is_training) {
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// This is a workaround (WAR) to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest and becomes the returned workspace size
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must be equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
}
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
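In effect, the loop queries the workspace size for every possible segment count and keeps the last (largest) one. A hedged Python sketch with a hypothetical `query_workspace_size` probe standing in for the `nvte_fused_attn_fwd*` dry runs above:

```python
def max_workspace_size(input_batch, max_segments_per_seq, is_ragged,
                       query_workspace_size):
    # `query_workspace_size` is a hypothetical per-segment-count probe;
    # calling it also pre-builds (and caches) each cuDNN graph.
    max_num_segments = input_batch * max_segments_per_seq if is_ragged else input_batch
    size = 0
    for num_segments in range(input_batch, max_num_segments + 1):
        size = query_workspace_size(num_segments)  # the last call is the largest
    return size

sizes = iter([64, 96, 128])
print(max_workspace_size(1, 3, True, lambda n: next(sizes)))  # 128
```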
......@@ -260,18 +183,23 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6];
void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];
/* Output buffer from XLA */
void *output = buffers[7];
void *softmax_aux = buffers[8];
void *rng_state = buffers[9];
void *workspace = buffers[10];
void *output = buffers[9];
void *softmax_aux = buffers[10];
void *rng_state = buffers[11];
void *workspace = buffers[12];
/* Descriptor */
auto input_batch = descriptor.input_batch;
......@@ -286,8 +214,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
auto is_training = descriptor.is_training;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
/* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -296,14 +225,33 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
size_t num_segments = input_batch;  // For non-THD formats, num_segments equals input_batch
if (is_ragged) {
// the workspace can be reused here because it is not used by the cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
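// For ragged (THD) inputs, zero the output first: token positions outside
// the valid segments are not written by the attention kernel.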
cudaMemsetAsync(output, 0,
input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
......@@ -323,19 +271,18 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor,
&aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
......@@ -345,11 +292,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
......@@ -359,13 +305,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......@@ -377,46 +322,89 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto v_shape = k_shape;
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
// For kv_packed
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
// For separate q, k, v
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
// F16 doesn't use this tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor;
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
// This is a workaround (WAR) to pre-create all possible cuDNN graphs at JIT compile time
size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest and becomes the returned workspace size
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
}
auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
......@@ -426,6 +414,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3];
......@@ -435,11 +426,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
void *doutput = buffers[7];
void *q_cu_seqlens = buffers[8];
void *kv_cu_seqlens = buffers[9];
void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;
/* Output buffer from XLA */
/* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
void *dbias = buffers[13];
void *workspace = buffers[14];
/* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */
void *dbias = buffers[15];
void *workspace = buffers[16];
/* Descriptor */
auto input_batch = descriptor.input_batch;
......@@ -454,8 +447,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
/* Input tensors */
auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -463,13 +456,30 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto output_tensor = TensorWrapper(output, output_shape, dtype);
auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
size_t num_segments = input_batch;  // For non-THD formats, num_segments equals input_batch
if (is_ragged) {
// the workspace can be reused here because it is not used by the cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors;
......@@ -486,42 +496,54 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto wkspace_dtype = descriptor.wkspace_dtype;
auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[10];
auto dqkv = buffers[12];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
if (is_ragged) {
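// Zero the gradient buffer for ragged inputs; only positions covered by
// valid segments are written by the backward kernel below.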
size_t dqkv_size =
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv = buffers[1];
auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
auto dq = buffers[10];
auto dq = buffers[12];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dkv = buffers[11];
auto dkv = buffers[13];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
if (is_ragged) {
size_t dq_size =
std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
size_t dkv_size =
std::accumulate(kv_shape.cbegin(), kv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, dkv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype);
......@@ -531,21 +553,31 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto v = buffers[2];
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype);
auto dq = buffers[10];
auto dq = buffers[12];
auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
auto dk = buffers[11];
auto dk = buffers[13];
auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
auto dv = buffers[12];
auto dv = buffers[14];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
size_t dq_size =
std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
size_t dk_size =
std::accumulate(k_shape.cbegin(), k_shape.cend(), 1, std::multiplies<size_t>());
size_t dv_size = dk_size;
cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, dk_size * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, dv_size * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
......
......@@ -66,13 +66,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
bool is_training) {
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
dtype, wkspace_dtype, is_training});
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
}
} // namespace jax
......
......@@ -67,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......@@ -92,7 +93,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
.value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
.value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
.value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
.value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU)
......
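With the THD layouts and `nvte_get_qkv_format` now exposed to Python, the layout-to-format mapping can be queried directly. A minimal sketch, assuming the same module path the JAX tests import from and that the pybind enums compare directly:

```python
import transformer_engine.transformer_engine_jax as tex

layout = tex.NVTE_QKV_Layout.NVTE_THD_THD_THD
assert tex.nvte_get_qkv_format(layout) == tex.NVTE_QKV_Format.NVTE_THD
```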
......@@ -45,5 +45,29 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
NVTE_CHECK_CUDA(cudaGetLastError());
}
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) return;
if (cu_seqlen[tid] > 0) {
// atomicAdd only supports 32-bit dtypes
atomicAdd(out, 1);
}
}
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
// requires a workspace of at least 4 bytes
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
return hout;
}
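The host wrapper launches the kernel, copies the count back, and synchronizes. A JAX sketch of the same reduction, assuming the convention that entries for unused segments are non-positive:

```python
import jax.numpy as jnp

def runtime_num_segments_sketch(cu_seqlen):
    # Matches the kernel's `if (cu_seqlen[tid] > 0) atomicAdd(out, 1);`
    return jnp.sum(cu_seqlen > 0)

print(runtime_num_segments_sketch(jnp.array([0, 3, 8, 15, -1, -1])))  # 3
```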
} // namespace jax
} // namespace transformer_engine
......@@ -28,6 +28,8 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
......
......@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout
from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
from ..attention import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from ..attention import fused_attn
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -268,6 +268,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
scale_factor = self.scale_factor
del self.scale_factor
# TODO(rewang): integrate THD format
if self.qkv_layout == QKVLayout.BS3HD:
"""qkvpacked format, treat
query: qkvpacked tensor, shape = [..., 3, h, d]
......@@ -277,13 +278,14 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
qkv_packed = query
if self.transpose_batch_sequence:
qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_qkvpacked(
qkv_packed,
x = fused_attn(
(qkv_packed,),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
......@@ -298,14 +300,14 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
if self.transpose_batch_sequence:
query = query.transpose([1, 0, 2, 3])
kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
x = fused_attn_kvpacked(
query,
kv_packed,
x = fused_attn(
(query, kv_packed),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
......@@ -316,14 +318,13 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
key = key.transpose([1, 0, 2, 3])
value = value.transpose([1, 0, 2, 3])
x = fused_attn(
query,
key,
value,
(query, key, value),
bias,
mask,
seed,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
qkv_layout=self.qkv_layout,
scaling_factor=scale_factor,
dropout_probability=self.attention_dropout,
is_training=not deterministic,
......