Unverified commit 687697a7, authored by Reese Wang, committed by GitHub

[JAX] Add experimental internal-use THD (packed) fused attn API (#964)



* Integrate experimental ragged offset
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Use per-sequence-based offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Format
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove v/o_seq_offsets
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add FP16 sanity tests and remove forward tests from the automatically run tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Enhance input checks
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Separate fused attn into 2 different APIs and add the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add experimental to the docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix lint
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add runtime segments check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove finished TODO
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 7669bf3d
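
For context: this change replaces the layout-specific entry points (fused_attn_qkvpacked,
fused_attn_kvpacked) with a single fused_attn that takes a tuple of tensors plus an explicit
qkv_layout, and adds an experimental fused_attn_thd for packed (ragged) inputs. A minimal
sketch of the intended call pattern, with invented shapes (the signature follows the
docstrings added in this diff; nothing here is taken verbatim from the tests):

    import jax.numpy as jnp
    from transformer_engine.jax.attention import (
        AttnBiasType, AttnMaskType, QKVLayout, fused_attn,
    )

    b, s, h, d = 2, 128, 16, 64
    qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)  # interleaved QKV packed
    out = fused_attn(
        (qkv,),                 # (qkv_packed,) for the BS3HD layout
        None,                   # bias
        None,                   # mask (unused for NO_MASK/CAUSAL_MASK)
        None,                   # dropout seed
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.CAUSAL_MASK,
        qkv_layout=QKVLayout.BS3HD,
        scaling_factor=d**-0.5,
        dropout_probability=0.0,
        is_training=True,
    )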
@@ -15,8 +15,7 @@ from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.attention import (
     is_fused_attn_kernel_available,
-    fused_attn_qkvpacked,
-    fused_attn_kvpacked,
+    fused_attn,
     AttnBiasType,
     AttnMaskType,
     QKVLayout,
@@ -120,11 +119,15 @@ class TestDistributedSelfAttn:
         def target_func(qkv, bias, mask):
             return jnp.mean(
-                fused_attn_qkvpacked(
-                    qkv,
+                fused_attn(
+                    (qkv,),
                     bias,
                     mask,
                     None,
+                    None,
+                    None,
+                    None,
+                    None,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scaling_factor,
@@ -252,12 +255,15 @@ class TestDistributedCrossAttn:
         def target_func(q, kv, mask):
             return jnp.mean(
-                fused_attn_kvpacked(
-                    q,
-                    kv,
+                fused_attn(
+                    (q, kv),
                     None,
                     mask,
                     None,
+                    None,
+                    None,
+                    None,
+                    None,
                     attn_bias_type=attn_bias_type,
                     attn_mask_type=attn_mask_type,
                     scaling_factor=scaling_factor,
...
@@ -9,6 +9,7 @@ from math import sqrt
 import jax
 import jax.numpy as jnp
+import numpy as np
 import pytest
 from flax.linen import combine_masks
@@ -22,12 +23,12 @@ from transformer_engine.jax.attention import (
     AttnBiasType,
     AttnMaskType,
     QKVLayout,
-    fused_attn_qkvpacked,
-    fused_attn_kvpacked,
+    QKVFormat,
     fused_attn,
+    fused_attn_thd,
+    get_qkv_format,
 )
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
 from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
 from utils import assert_allclose
@@ -102,7 +103,7 @@ def is_causal_mask(mask: AttnMaskType):
     return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


-def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
+def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
     """
     Create inverse padded causal mask where `True` means allowing the corresponding
     position to participate in attention and `False` means masking out that position.
@@ -110,31 +111,75 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
     q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
     kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
     inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
-    inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
-    return combine_masks(inv_causal_mask, inv_padding_mask)
+    return inv_causal_mask


-def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
+def make_mask(
+    q_token: ArrayLike,
+    kv_token: ArrayLike,
+    segment_pad_q: ArrayLike,
+    segment_pad_kv: ArrayLike,
+    attn_mask_type: AttnMaskType,
+) -> Array:
     """
     Create attention mask based on mask type. A `True` value in the mask means
     masking out the corresponding position and a `False` value means allowing
     that position to participate in attention.
     """
+    inv_mask = make_attention_mask(
+        q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
+    )
     if is_causal_mask(attn_mask_type):
-        inv_mask = make_decoder_mask(q_token, kv_token)
-    else:
-        inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
+        inv_causal_mask = make_causal_mask(q_token, kv_token)
+        inv_mask = combine_masks(inv_causal_mask, inv_mask)
+    if segment_pad_q is not None and segment_pad_kv is not None:
+        inv_pad_mask = make_attention_mask(
+            segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
+        )
+        inv_mask = combine_masks(inv_pad_mask, inv_mask)
     mask = jnp.logical_not(inv_mask)
     return mask


+def get_seqlens_and_offsets(segment_ids, segment_pad):
+    batch, max_seqlen = segment_ids.shape
+    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
+    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
+    seqlens = seqlens_with_zero[..., 1:]
+
+    def _find_offsets(x):
+        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
+        first_column = jnp.ones((x.shape[0], 1), dtype=bool)
+        same_as_previous = jnp.hstack((first_column, same_as_previous))
+        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
+            same_as_previous
+        ).squeeze(-1)
+
+    offsets = _find_offsets(segment_ids)
+    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
+    if segment_pad is not None:
+        segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
+        padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
+        output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
+    else:
+        output = jnp.insert(seqlens, -1, values=0, axis=-1)
+    return output, offsets
+
+
+@jax.jit
+def _split_valid_and_invalid(primitive, reference, pad):
+    """Use JIT to speed up the verifications"""
+    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
+    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
+    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
+    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
+    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


-def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
+def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
     """
     JAX native dot product attention implementation
     """
-    attn_mask_type = kwargs["attn_mask_type"]
-    mask = make_mask(q_token, kv_token, attn_mask_type)
     output = general_dot_product_attention(
         query,
         key,
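
For reference, a hand-traced toy run of get_seqlens_and_offsets above (the inputs are
invented for illustration; the expected values were traced by hand from the code):

    import jax.numpy as jnp
    # Two segments [1, 1] and [2, 2] followed by padding; the last token of
    # segment 2 is additionally marked as intra-segment padding.
    segment_ids = jnp.asarray([[1, 1, 2, 2, 0, 0]])
    segment_pad = jnp.asarray([[0, 0, 0, 1, 1, 1]])
    seqlens, offsets = get_seqlens_and_offsets(segment_ids, segment_pad)
    # seqlens counts only the non-padded tokens of each segment:
    #   [[2, 1, 0, 0, 0, 0]]
    # offsets marks each segment start, with -1 in unused slots:
    #   [[0, 2, -1, -1, -1, -1, -1]]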
@@ -150,29 +195,43 @@ def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
     return output.astype(query.dtype)


-def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
+def customcall_fused_dpa(
+    query,
+    key,
+    value,
+    bias,
+    mask,
+    seqlens_q,
+    seqlens_kv,
+    offsets_q,
+    offsets_kv,
+    dropout_rng,
+    **kwargs,
+):
     """
     TE customcall dot product attention implementation
     """
-    attn_mask_type = kwargs["attn_mask_type"]
-    mask = make_mask(q_token, kv_token, attn_mask_type)
-    qkv_layout = kwargs.pop("qkv_layout")
+    qkv_layout = kwargs["qkv_layout"]
+    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
     match qkv_layout:
-        case QKVLayout.BS3HD:
+        case QKVLayout.BS3HD | QKVLayout.T3HD:
             query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
             qkv = jnp.concatenate((query, key, value), axis=-3)
-            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
-        case QKVLayout.BSHD_BS2HD:
+            qkv_args = (qkv,)
+        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
             key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
             kv = jnp.concatenate((key, value), axis=-3)
-            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
-                query.dtype
-            )
-        case QKVLayout.BSHD_BSHD_BSHD:
-            return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
-                query.dtype
-            )
+            qkv_args = (query, kv)
+        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
+            qkv_args = (query, key, value)
+        case _:
+            raise ValueError(f"Unsupported {qkv_layout=}")
+    if not is_thd:
+        kwargs.pop("max_segments_per_seq")
+        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
+    return fused_attn_thd(
+        qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
+    ).astype(query.dtype)


 class BiasShape(Enum):
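
For reference, the qkv_args tuples assembled above per layout (shapes invented for
illustration):

    # BS3HD / T3HD:                   qkv_args = (qkv,)     # qkv: [b, s, 3, h, d]
    # BSHD_BS2HD / THD_T2HD:          qkv_args = (q, kv)    # q: [b, s, h, d], kv: [b, s, 2, h, d]
    # BSHD_BSHD_BSHD / THD_THD_THD:   qkv_args = (q, k, v)  # each: [b, s, h, d]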
@@ -207,11 +266,18 @@ class FusedAttnRunner:
     bias_shape: BiasShape

     def _check_configs(self):
-        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
-            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")
-
-        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
-            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
+        # TODO(rewang): probably adds this in is_fused_attn_available
+        if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
+            AttnMaskType.PADDING_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+        ]:
+            pytest.skip("THD format requires padding masks.")
+
+        if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD:
+            if self.num_heads_q != self.num_heads_kv:
+                pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
+            if self.max_seqlen_q != self.max_seqlen_kv:
+                pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")

         self.backend = FusedAttnHelper(
             self.dtype,
@@ -293,10 +359,78 @@ class FusedAttnRunner:
             pad_len = int(max_seqlen * pad_ratio)
             valid_len = max_seqlen - pad_len
             tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
-            return valid_len, tokens
+            return tokens, jnp.logical_not(tokens)
+
+        def generate_random_segment_ids(
+            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
+        ):
+            rng = np.random.default_rng(seed=seed)
+            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
+            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
+            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
+            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
+            # Not include paddings
+            max_segment_size = sequence_length // num_segments
+            for i in range(batch_size):
+                current_pos = 0
+                segment_id = 1
+                for _ in range(num_segments):
+                    segment_size = rng.integers(1, max_segment_size + 1)
+                    if current_pos + segment_size > sequence_length:
+                        break
+                    segment_end = current_pos + segment_size
+                    segment_ids[i, current_pos:segment_end] = segment_id
+                    if with_segment_pad:
+                        num_valid = rng.integers(1, segment_size + 1)
+                        segment_pad[i, current_pos + num_valid : segment_end] = 1
+                    current_pos = segment_end
+                    segment_id += 1
+                segment_pad[i, current_pos:sequence_length] = 1
+            return segment_ids, segment_pad
+
+        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
+            self.num_segments_per_seq = 3
+            self.token_q, self.segment_pad_q = generate_random_segment_ids(
+                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
+            )
+            # TODO(rewang): Check if qkvpacked supported different q/kv
+            # TODO(rewang): Causal with different q/kv segment_id fails
+            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
+                self.token_kv = self.token_q
+                self.segment_pad_kv = self.segment_pad_q
+            else:
+                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
+                    self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
+                )
+            self.pad_q = self.segment_pad_q
+            self.pad_kv = self.segment_pad_kv
+        else:
+            self.num_segments_per_seq = 1
+            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
+            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
+            self.segment_pad_q = self.segment_pad_kv = None
+
+        self.mask = make_mask(
+            self.token_q,
+            self.token_kv,
+            self.segment_pad_q,
+            self.segment_pad_kv,
+            self.attn_mask_type,
+        )

-        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
-        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
+        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
+            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
+                self.token_q, self.segment_pad_q
+            )
+            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
+                self.token_kv, self.segment_pad_kv
+            )
+            self.mask_for_customcall = None  # THD format doesn't support mask
+        else:
+            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
+            self.mask_for_customcall = self.mask

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
         self.scaling_factor = 1.0 / sqrt(self.head_dim)
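
For reference, the shape of data produced by generate_random_segment_ids above, taken
from the comments in the code (one row of an invented batch):

    # segment_ids: [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0]   0 marks trailing padding
    # segment_pad: [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1]   1 marks padded positions,
    #   including intra-segment padding when with_segment_pad=True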
@@ -307,7 +441,19 @@ class FusedAttnRunner:
         """
         self._setup_inputs()
-        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        customcall_args = [
+            self.q,
+            self.k,
+            self.v,
+            self.bias,
+            self.mask_for_customcall,
+            self.seqlens_q,
+            self.seqlens_kv,
+            self.offsets_q,
+            self.offsets_kv,
+            self.dropout_rng,
+        ]
         kwargs = {
             "attn_bias_type": self.attn_bias_type,
             "attn_mask_type": self.attn_mask_type,
@@ -315,17 +461,19 @@ class FusedAttnRunner:
             "dropout_probability": self.dropout_prob,
             "is_training": self.is_training,
             "qkv_layout": self.qkv_layout,
+            "max_segments_per_seq": self.num_segments_per_seq,
         }

         # Convert the outputs to float32 for the elementwise comparison
-        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
-        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
+        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
+        reference_out = jax_dpa(*args, **kwargs)

         if self.is_training and self.dropout_prob > 0.0:
             return

-        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
-        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)
+        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
+            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
+        )

         assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
         assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
@@ -341,14 +489,28 @@ class FusedAttnRunner:
         def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
-            gradient_multiplier = self.valid_len_q * self.num_heads_q
+            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
             if is_causal_mask(self.attn_mask_type):
                 gradient_multiplier /= 10
             # Keep only valid result for the gradient
-            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
+            ret_valid = jnp.where(
+                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
+            )
             return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

-        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        customcall_args = [
+            self.q,
+            self.k,
+            self.v,
+            self.bias,
+            self.mask_for_customcall,
+            self.seqlens_q,
+            self.seqlens_kv,
+            self.offsets_q,
+            self.offsets_kv,
+            self.dropout_rng,
+        ]
         kwargs = {
             "attn_bias_type": self.attn_bias_type,
             "attn_mask_type": self.attn_mask_type,
@@ -356,6 +518,7 @@ class FusedAttnRunner:
             "dropout_probability": self.dropout_prob,
             "is_training": self.is_training,
             "qkv_layout": self.qkv_layout,
+            "max_segments_per_seq": self.num_segments_per_seq,
         }

         # We can compute dBias only for the [1, h, s, s] layout
@@ -377,54 +540,56 @@ class FusedAttnRunner:
             )
         )

-        primitive_out, primitive_dgrad = jitted_primitive(*args)
+        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
         reference_out, reference_dgrad = jitted_reference(*args)

         # Skip elementwise comparison when dropout enabled
         if self.dropout_prob > 0.0:
             return

-        assert_allclose(
-            primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
-        )
+        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

-        def check_dqkv(primitive, reference, valid_len):
-            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
-            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)
+        def check_dqkv(primitive, reference, pad):
+            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
+                _split_valid_and_invalid(primitive, reference, pad)
+            )

             assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
             assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
             assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

-        # Convert the outputs to float32 for the elementwise comparison
-        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
-        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])
+        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
+        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

-        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
-        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
-        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)
+        check_dqkv(primitive_dq, reference_dq, self.pad_q)
+        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
+        check_dqkv(primitive_dv, reference_dv, self.pad_kv)

         if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
-            primitive_dbias = jnp.float32(primitive_dgrad[3])
-            reference_dbias = jnp.float32(reference_dgrad[3])
+            primitive_dbias = primitive_dgrad[3]
+            reference_dbias = reference_dgrad[3]

+            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
+            bias_mask = self.mask[0, 0]

+            # Assert all masked dbias are 0s
             assert_allclose(
-                primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
-                jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
+                jnp.where(bias_mask, primitive_dbias, 0),
+                jnp.zeros_like(primitive_dbias),
                 dtype=self.dtype,
             )

             # dbias padded part
             assert_allclose(
-                primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
-                reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
+                jnp.where(bias_mask, primitive_dbias, 0),
+                jnp.where(bias_mask, reference_dbias, 0),
                 dtype=self.dtype,
             )

             # dbias valid part
             assert_allclose(
-                primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
-                reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
+                jnp.where(bias_mask, 0, primitive_dbias),
+                jnp.where(bias_mask, 0, reference_dbias),
                 dtype=self.dtype,
             )
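
For reference, the jnp.where partitioning used for the dbias checks above (invented
2x2 values):

    bias_mask = jnp.asarray([[False, True], [True, False]])
    dbias = jnp.asarray([[1.0, 2.0], [3.0, 4.0]])
    masked_part = jnp.where(bias_mask, dbias, 0)  # [[0., 2.], [3., 0.]]
    valid_part = jnp.where(bias_mask, 0, dbias)   # [[1., 0.], [0., 4.]]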
@@ -454,24 +619,21 @@ class FusedAttnRunner:
             pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
             pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
             pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+            pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
+            pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
+            pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
         ],
     )
     @pytest.mark.parametrize(
-        "dtype",
-        [
-            pytest.param(jnp.bfloat16, id="BF16"),
-            pytest.param(jnp.float16, id="FP16"),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "b, s_q, s_kv, h_q, h_kv, d",
+        "b, s_q, s_kv, h_q, h_kv, d, dtype",
         [
-            pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
-            pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
-            pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
-            pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"),
-            pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
-            pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
+            pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
+            pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
+            pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
+            pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
+            pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1048-12-12-64-BF16-CROSS"),
+            pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
+            pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
         ],
     )
     @pytest.mark.parametrize(
@@ -494,7 +656,7 @@ class TestFusedAttn:
             pytest.param(False, id="INFERENCE"),
         ],
     )
-    def test_forward(
+    def _test_forward(
         b,
         s_q,
         s_kv,
@@ -511,6 +673,8 @@ class TestFusedAttn:
     ):
         """
         Test forward with parameterized configs
+        This test is not intended to run automatically during CI as it is time-consuming
+        It is kept for development and debugging
         """
         runner = FusedAttnRunner(
             b,
...
@@ -5,6 +5,7 @@
 from enum import Enum
 from functools import partial
+from typing import Optional, Tuple

 from jax.ad_checkpoint import checkpoint_name
 import jax
 import jax.numpy as jnp
@@ -12,6 +13,8 @@ import jax.numpy as jnp
 from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
 from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
+from transformer_engine.transformer_engine_jax import NVTE_QKV_Format
+from transformer_engine.transformer_engine_jax import nvte_get_qkv_format

 from . import cpp_extensions as tex
@@ -43,11 +46,42 @@ class AttnMaskType(Enum):
 class QKVLayout(Enum):
-    """QKV layout"""
+    """
+    BSHD Format:
+        - BS3HD: q, k, v are interleaved and packed as a tensor with shape [b, s, 3, h, d].
+        - BSHD_BS2HD: q with shape [b, s, h, d]; k, v are interleaved with shape [b, s, 2, h, d].
+        - BSHD_BSHD_BSHD: q, k, v are separate tensors, each with shape [b, s, h, d].
+    THD Format: Same shapes as the BSHD layouts, but multiple segments may be packed in a sequence.
+        - T3HD: q, k, v are interleaved and packed as a tensor with shape [b, s, 3, h, d].
+        - THD_T2HD: q with shape [b, s, h, d]; k, v are interleaved with shape [b, s, 2, h, d].
+        - THD_THD_THD: q, k, v are separate tensors, each with shape [b, s, h, d].
+    """

     BS3HD = NVTE_QKV_Layout.NVTE_BS3HD
     BSHD_BS2HD = NVTE_QKV_Layout.NVTE_BSHD_BS2HD
     BSHD_BSHD_BSHD = NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD
+    T3HD = NVTE_QKV_Layout.NVTE_T3HD
+    THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD
+    THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD
+
+
+class QKVFormat(Enum):
+    """
+    SBHD: q, k, v memory layout with shape [s, b, ..., h, d]
+    BSHD: q, k, v memory layout with shape [b, s, ..., h, d]
+    THD: same memory layout as BSHD, but multiple segments may be packed in a sequence.
+    """
+
+    SBHD = NVTE_QKV_Format.NVTE_SBHD
+    BSHD = NVTE_QKV_Format.NVTE_BSHD
+    THD = NVTE_QKV_Format.NVTE_THD
+
+
+def get_qkv_format(qkv_layout):
+    """
+    Get the qkv_format from the given qkv_layout
+    """
+    return QKVFormat(nvte_get_qkv_format(qkv_layout.value))


 def canonicalize_attn_mask_type(attn_mask_type: str):
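
A minimal usage sketch of the new helper (the enum members come from this diff; the
asserted pairings are illustrative):

    from transformer_engine.jax.attention import QKVFormat, QKVLayout, get_qkv_format

    assert get_qkv_format(QKVLayout.BS3HD) == QKVFormat.BSHD
    assert get_qkv_format(QKVLayout.THD_THD_THD) == QKVFormat.THD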
@@ -102,414 +136,357 @@ def is_fused_attn_kernel_available(
     ).is_fused_attn_kernel_available()


-def fused_attn_qkvpacked(
-    qkv: jnp.ndarray,
-    bias: jnp.ndarray | None,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray | None,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    """
-    Fused attention with the qkvpacked inputs
-    """
-    output = _fused_attn_qkvpacked(
-        qkv,
-        bias,
-        mask,
-        seed,
-        attn_bias_type=attn_bias_type,
-        attn_mask_type=attn_mask_type,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    return output
-
-
-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8))
-def _fused_attn_qkvpacked(
-    qkv: jnp.ndarray,
-    bias: jnp.ndarray | None,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray | None,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    output, _ = _fused_attn_fwd_qkvpacked_rule(
-        qkv,
-        bias,
-        mask,
-        seed,
-        attn_bias_type,
-        attn_mask_type,
-        scaling_factor,
-        dropout_probability,
-        is_training,
-    )
-    return output
-
-
-def _fused_attn_fwd_qkvpacked_rule(
-    qkv: jnp.ndarray,
-    bias: jnp.ndarray | None,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray | None,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
-        batch, seqlen, *_ = qkv.shape
-        actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32)
-    else:
-        assert mask is not None
-        mask = jnp.logical_not(mask)
-        actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
-    output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked(
-        qkv,
-        bias,
-        actual_seqlen,
-        seed,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    output = checkpoint_name(output, "context")
-    softmax_aux = checkpoint_name(softmax_aux, "context")
-    rng_state = checkpoint_name(rng_state, "context")
-    return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen)
-
-
-def _fused_attn_bwd_qkvpacked_rule(
-    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
-):
-    qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx
-    grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(
-        qkv,
-        bias,
-        softmax_aux,
-        rng_state,
-        output,
-        dz,
-        actual_seqlen,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    if attn_bias_type == AttnBiasType.NO_BIAS:
-        grad_bias = None
-    return grad_qkv, grad_bias, None, None
-
-
-_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule)
-
-
-def fused_attn_kvpacked(
-    q: jnp.ndarray,
-    kv: jnp.ndarray,
-    bias: jnp.ndarray,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    """
-    Fused attention with the kvpacked inputs
-    """
-    output = _fused_attn_kvpacked(
-        q,
-        kv,
-        bias,
-        mask,
-        seed,
-        attn_bias_type=attn_bias_type,
-        attn_mask_type=attn_mask_type,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    return output
-
-
-@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
-def _fused_attn_kvpacked(
-    q: jnp.ndarray,
-    kv: jnp.ndarray,
-    bias: jnp.ndarray,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    output, _ = _fused_attn_fwd_kvpacked_rule(
-        q,
-        kv,
-        bias,
-        mask,
-        seed,
-        attn_bias_type,
-        attn_mask_type,
-        scaling_factor,
-        dropout_probability,
-        is_training,
-    )
-    return output
-
-
-def _fused_attn_fwd_kvpacked_rule(
-    q,
-    kv,
-    bias,
-    mask,
-    seed,
-    attn_bias_type,
-    attn_mask_type,
-    scaling_factor,
-    dropout_probability,
-    is_training,
-):
-    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
-        batch, s_q, *_ = q.shape
-        s_kv = kv.shape[1]
-        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
-        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
-    else:
-        assert mask is not None
-        mask = jnp.logical_not(mask)
-        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
-        if attn_mask_type == AttnMaskType.PADDING_MASK:
-            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
-        else:
-            # When mask is causal, the actual seqlen is not the last row, use max to find it
-            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
-    output, softmax_aux, rng_state = tex.fused_attn_fwd_kvpacked(
-        q,
-        kv,
-        bias,
-        q_actual_seqlen,
-        kv_actual_seqlen,
-        seed,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    output = checkpoint_name(output, "context")
-    softmax_aux = checkpoint_name(softmax_aux, "context")
-    rng_state = checkpoint_name(rng_state, "context")
-    return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen)
-
-
-def _fused_attn_bwd_kvpacked_rule(
-    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
-):
-    q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
-    grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(
-        q,
-        kv,
-        bias,
-        softmax_aux,
-        rng_state,
-        output,
-        dz,
-        q_actual_seqlen,
-        kv_actual_seqlen,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    if attn_bias_type == AttnBiasType.NO_BIAS:
-        grad_bias = None
-    return grad_q, grad_kv, grad_bias, None, None
-
-
-_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule)
-
-
-def fused_attn(
-    q: jnp.ndarray,
-    k: jnp.ndarray,
-    v: jnp.ndarray,
-    bias: jnp.ndarray,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    """
-    Dot product attention with the seperated query, key, value
-    """
-    output = _fused_attn(
-        q,
-        k,
-        v,
-        bias,
-        mask,
-        seed,
-        attn_bias_type=attn_bias_type,
-        attn_mask_type=attn_mask_type,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    return output
-
-
-@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10))
-def _fused_attn(
-    q: jnp.ndarray,
-    k: jnp.ndarray,
-    v: jnp.ndarray,
-    bias: jnp.ndarray,
-    mask: jnp.ndarray,
-    seed: jnp.ndarray,
-    attn_bias_type: AttnBiasType,
-    attn_mask_type: AttnMaskType,
-    scaling_factor: float,
-    dropout_probability: float,
-    is_training: bool,
-):
-    output, _ = _fused_attn_fwd_rule(
-        q,
-        k,
-        v,
-        bias,
-        mask,
-        seed,
-        attn_bias_type,
-        attn_mask_type,
-        scaling_factor,
-        dropout_probability,
-        is_training,
-    )
-    return output
-
-
-def _fused_attn_fwd_rule(
-    q,
-    k,
-    v,
-    bias,
-    mask,
-    seed,
-    attn_bias_type,
-    attn_mask_type,
-    scaling_factor,
-    dropout_probability,
-    is_training,
-):
-    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
-        batch, s_q, *_ = q.shape
-        s_kv = k.shape[1]
-        q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32)
-        kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32)
-    else:
-        assert mask is not None
-        mask = jnp.logical_not(mask)
-        q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
-        if attn_mask_type == AttnMaskType.PADDING_MASK:
-            kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]  # shape = (b,)
-        else:
-            # When mask is causal, the actual seqlen is not the last row, use max to find it
-            kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
-    output, softmax_aux, rng_state = tex.fused_attn_fwd(
-        q,
-        k,
-        v,
-        bias,
-        q_actual_seqlen,
-        kv_actual_seqlen,
-        seed,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    output = checkpoint_name(output, "context")
-    softmax_aux = checkpoint_name(softmax_aux, "context")
-    rng_state = checkpoint_name(rng_state, "context")
-    return output, (
-        q,
-        k,
-        v,
-        bias,
-        softmax_aux,
-        rng_state,
-        output,
-        q_actual_seqlen,
-        kv_actual_seqlen,
-    )
-
-
-def _fused_attn_bwd_rule(
-    attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, ctx, dz
-):
-    q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx
-    grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(
-        q,
-        k,
-        v,
-        bias,
-        softmax_aux,
-        rng_state,
-        output,
-        dz,
-        q_actual_seqlen,
-        kv_actual_seqlen,
-        attn_bias_type=attn_bias_type.value,
-        attn_mask_type=attn_mask_type.value,
-        scaling_factor=scaling_factor,
-        dropout_probability=dropout_probability,
-        is_training=is_training,
-    )
-    if attn_bias_type == AttnBiasType.NO_BIAS:
-        grad_bias = None
-    return grad_q, grad_k, grad_v, grad_bias, None, None
+def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
+    match qkv_layout:
+        case QKVLayout.BS3HD | QKVLayout.T3HD:
+            assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}"
+            batch, q_max_seqlen, *_ = qkv[0].shape
+            kv_max_seqlen = q_max_seqlen
+        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
+            assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}"
+            batch, q_max_seqlen, *_ = qkv[0].shape
+            kv_max_seqlen = qkv[1].shape[1]
+        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
+            assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}"
+            batch, q_max_seqlen, *_ = qkv[0].shape
+            kv_max_seqlen = qkv[1].shape[1]
+        case _:
+            raise ValueError(f"Unsupported {qkv_layout=}")
+    return batch, q_max_seqlen, kv_max_seqlen
+
+
+def fused_attn(
+    qkv: Tuple[jnp.ndarray, ...],
+    bias: Optional[jnp.ndarray],
+    mask: Optional[jnp.ndarray],
+    seed: Optional[jnp.ndarray],
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+):
+    """
+    Perform non-THD (non-packed) cuDNN fused attention.
+
+    This function implements the following formula:
+        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
+
+    Args:
+        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
+            It supports three formats:
+            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
+              and value have the same shape (e.g., self-attention).
+            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
+              query has a different shape (e.g., cross-attention).
+            - `(query, key, value)`: For separate query, key, and value tensors.
+        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
+        mask (Optional[jnp.ndarray]):
+            An optional mask tensor to mask out the attention scores, `True` means mask out.
+            Intra-sequence padding is not valid; padded tokens may only be at the right-most
+            end of a sequence, otherwise the results will be wrong.
+        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
+        attn_bias_type (AttnBiasType): Type of attention bias.
+        attn_mask_type (AttnMaskType): Type of attention mask.
+        qkv_layout (QKVLayout): Layout of the QKV tensors.
+        scaling_factor (float): Scaling factor for the attention scores.
+        dropout_probability (float): Dropout probability to apply during attention.
+        is_training (bool): Flag indicating whether the model is in training mode.
+
+    Returns:
+        (jnp.ndarray): The output tensor from the fused attention.
+    """
+    assert (
+        get_qkv_format(qkv_layout) != QKVFormat.THD
+    ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format."
+
+    # Check inputs qkv
+    match qkv_layout:
+        case NVTE_QKV_Layout.NVTE_BS3HD:
+            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
+        case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+            assert (
+                len(qkv) == 2
+            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
+        case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+            assert (
+                len(qkv) == 3
+            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
+
+    # convert the mask to seqlens, mask doesn't support ragged offsets
+    if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+        batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
+        q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
+        kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
+    else:
+        assert mask is not None
+        mask = jnp.logical_not(mask)
+        q_seq_lens = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]
+        if attn_mask_type == AttnMaskType.PADDING_MASK:
+            kv_seq_lens = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0]
+        else:
+            # When mask is causal, the actual seqlen is not the last row, use max to find it
+            kv_seq_lens = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))
+
+    output = _fused_attn(
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        None,
+        None,
+        seed,
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=1,
+    )
+    return output
+
+
+def fused_attn_thd(
+    qkv: Tuple[jnp.ndarray, ...],
+    bias: Optional[jnp.ndarray],
+    q_seq_lens: jnp.ndarray,
+    kv_seq_lens: jnp.ndarray,
+    q_seq_offsets: jnp.ndarray,
+    kv_seq_offsets: jnp.ndarray,
+    seed: Optional[jnp.ndarray],
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+    max_segments_per_seq: int = 1,
+):
+    """
+    (Experimental) Perform THD (packed) cuDNN fused attention.
+
+    This function implements the following formula:
+        BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
+
+    Args:
+        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
+            It supports three formats:
+            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
+              and value have the same shape (e.g., self-attention).
+            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
+              query has a different shape (e.g., cross-attention).
+            - `(query, key, value)`: For separate query, key, and value tensors.
+        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
+        q_seq_lens (jnp.ndarray):
+            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
+            padded with -1.
+        kv_seq_lens (jnp.ndarray):
+            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused
+            positions are padded with -1.
+        q_seq_offsets (jnp.ndarray):
+            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
+            Unused positions are padded with -1.
+        kv_seq_offsets (jnp.ndarray):
+            The offsets in the sequence dim for the key and value, with shape
+            [batch, max_seqlen + 1]. Unused positions are padded with -1.
+        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
+        attn_bias_type (AttnBiasType): Type of attention bias.
+        attn_mask_type (AttnMaskType): Type of attention mask.
+        qkv_layout (QKVLayout): Layout of the QKV tensors.
+        scaling_factor (float): Scaling factor for the attention scores.
+        dropout_probability (float): Dropout probability to apply during attention.
+        is_training (bool): Flag indicating whether the model is in training mode.
+        max_segments_per_seq (int):
+            The maximum number of segments inside a sequence. This parameter constrains the
+            resource usage and needs to be static during the e2e training. The XLA compile
+            time and memory consumption are proportional to `max_segments_per_seq`.
+
+    Returns:
+        (jnp.ndarray): The output tensor from the fused attention.
+
+    Examples:
+        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
+        >>> b, s, h, d = 2, 4, 12, 64
+        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
+        >>> # 3 segments in first seq, 2 segments in second seq
+        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
+        >>> # seq_offsets need to include the end offset of the last segments
+        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
+        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
+        ...                      q_seq_offsets, kv_seq_offsets, None,
+        ...                      AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
+        ...                      QKVLayout.T3HD, 0.125, 0, True, 3)
+    """
+    assert (
+        get_qkv_format(qkv_layout) == QKVFormat.THD
+    ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."
+
+    # Check inputs qkv
+    match qkv_layout:
+        case NVTE_QKV_Layout.NVTE_T3HD:
+            assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
+        case NVTE_QKV_Layout.NVTE_THD_T2HD:
+            assert (
+                len(qkv) == 2
+            ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
+        case NVTE_QKV_Layout.NVTE_THD_THD_THD:
+            assert (
+                len(qkv) == 3
+            ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
+
+    batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
+    assert q_seq_lens.shape == (batch, q_max_seqlen)
+    assert kv_seq_lens.shape == (batch, kv_max_seqlen)
+    assert q_seq_offsets.shape == (batch, q_max_seqlen + 1)
+    assert kv_seq_offsets.shape == (batch, kv_max_seqlen + 1)
+
+    output = _fused_attn(
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        seed,
+        attn_bias_type=attn_bias_type,
+        attn_mask_type=attn_mask_type,
+        qkv_layout=qkv_layout,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+    )
+    return output
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13))
+def _fused_attn(
+    qkv: Tuple[jnp.ndarray, ...],
+    bias: Optional[jnp.ndarray],
+    q_seq_lens: jnp.ndarray,
+    kv_seq_lens: jnp.ndarray,
+    q_seq_offsets: Optional[jnp.ndarray],
+    kv_seq_offsets: Optional[jnp.ndarray],
+    seed: jnp.ndarray,
+    attn_bias_type: AttnBiasType,
+    attn_mask_type: AttnMaskType,
+    qkv_layout: QKVLayout,
+    scaling_factor: float,
+    dropout_probability: float,
+    is_training: bool,
+    max_segments_per_seq: int,
+):
+    output, _ = _fused_attn_fwd_rule(
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        seed,
+        attn_bias_type,
+        attn_mask_type,
+        qkv_layout,
+        scaling_factor,
+        dropout_probability,
+        is_training,
+        max_segments_per_seq,
+    )
+    return output
+
+
+def _fused_attn_fwd_rule(
+    qkv,
+    bias,
+    q_seq_lens,
+    kv_seq_lens,
+    q_seq_offsets,
+    kv_seq_offsets,
+    seed,
+    attn_bias_type,
+    attn_mask_type,
+    qkv_layout,
+    scaling_factor,
+    dropout_probability,
+    is_training,
+    max_segments_per_seq,
+):
+    output, softmax_aux, rng_state = tex.fused_attn_fwd(
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        seed,
+        attn_bias_type=attn_bias_type.value,
+        attn_mask_type=attn_mask_type.value,
+        qkv_layout=qkv_layout.value,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+    )
+    output = checkpoint_name(output, "context")
+    softmax_aux = checkpoint_name(softmax_aux, "context")
+    rng_state = checkpoint_name(rng_state, "context")
+    return output, (
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        softmax_aux,
+        rng_state,
+        output,
+    )
+
+
+def _fused_attn_bwd_rule(
+    attn_bias_type,
+    attn_mask_type,
+    qkv_layout,
+    scaling_factor,
+    dropout_probability,
+    is_training,
+    max_segments_per_seq,
+    ctx,
+    dz,
+):
+    (
+        qkv,
+        bias,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        softmax_aux,
+        rng_state,
+        output,
+    ) = ctx
+    grad_qkv, grad_bias = tex.fused_attn_bwd(
+        qkv,
+        bias,
+        softmax_aux,
+        rng_state,
+        output,
+        dz,
+        q_seq_lens,
+        kv_seq_lens,
+        q_seq_offsets,
+        kv_seq_offsets,
+        attn_bias_type=attn_bias_type.value,
+        attn_mask_type=attn_mask_type.value,
+        qkv_layout=qkv_layout.value,
+        scaling_factor=scaling_factor,
+        dropout_probability=dropout_probability,
+        is_training=is_training,
+        max_segments_per_seq=max_segments_per_seq,
+    )
+    if attn_bias_type == AttnBiasType.NO_BIAS:
+        grad_bias = None
+    return grad_qkv, grad_bias, None, None, None, None, None


 _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule)
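
Since _fused_attn carries the custom VJP, both public entry points remain
differentiable; a minimal sketch under the same imports as this module (shapes and
values are invented, mirroring what the tests exercise):

    def loss_fn(qkv):
        out = fused_attn(
            (qkv,), None, None, None,
            attn_bias_type=AttnBiasType.NO_BIAS,
            attn_mask_type=AttnMaskType.CAUSAL_MASK,
            qkv_layout=QKVLayout.BS3HD,
            scaling_factor=0.125,
            dropout_probability=0.0,
            is_training=True,
        )
        return jnp.mean(out, dtype=jnp.float32)

    dqkv = jax.grad(loss_fn)(jnp.zeros((2, 128, 3, 16, 64), jnp.bfloat16))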
@@ -5,6 +5,7 @@
 from dataclasses import dataclass
 from functools import partial, reduce
 import operator
+from typing import Optional, Tuple
 import warnings

 import jax.numpy as jnp
@@ -18,7 +19,9 @@ from transformer_engine.transformer_engine_jax import (
     NVTE_Bias_Type,
     NVTE_Mask_Type,
     NVTE_QKV_Layout,
+    NVTE_QKV_Format,
     NVTE_Fused_Attn_Backend,
+    nvte_get_qkv_format,
 )

 from .base import BasePrimitive, register_primitive
 from .custom_call import custom_caller, CustomCallArgsWrapper
@@ -37,10 +40,6 @@ from ..sharding import (

 __all__ = [
     "FusedAttnHelper",
-    "fused_attn_fwd_qkvpacked",
-    "fused_attn_bwd_qkvpacked",
-    "fused_attn_fwd_kvpacked",
-    "fused_attn_bwd_kvpacked",
     "fused_attn_fwd",
     "fused_attn_bwd",
 ]
@@ -88,18 +87,18 @@ class FusedAttnHelper:
     def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
         """Parse qkv aval"""
         match qkv_layout:
-            case NVTE_QKV_Layout.NVTE_BS3HD:
+            case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
                 *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
                 kv_batch_shape = q_batch_shape
                 kv_max_seqlen = q_max_seqlen
                 num_gqa_groups = attn_heads
                 kv_head_dim = q_head_dim
                 assert nqkv == 3
-            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
                 *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                 *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
                 assert nkv == 2
-            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
                 *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                 *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
                 assert k_aval.shape == v_aval.shape
@@ -158,8 +157,9 @@ def generate_cu_seqlen(actual_seqlen):
     """
     Generating cumsum seqlen for a batch
     """
-    cu_seqlen = jnp.cumsum(actual_seqlen)
-    cu_seqlen = jnp.hstack((0, cu_seqlen))
+    cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1)
+    cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen)
+    cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1)
     return cu_seqlen
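
A hand-traced illustration of the updated generate_cu_seqlen, assuming one -1-padded
slot in the flattened seqlen vector:

    # actual_seqlen         = [2, 1, -1]
    # cumsum (axis=-1)      = [2, 3,  2]
    # mask negative slots   = [2, 3, -1]
    # prepend a leading 0   = [0, 2, 3, -1]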
...@@ -170,7 +170,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -170,7 +170,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name = "te_fused_attn_forward" name = "te_fused_attn_forward"
multiple_results = True multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12) impl_static_args = (9, 10, 11, 12, 13, 14, 15)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -182,6 +182,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -182,6 +182,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias_aval, bias_aval,
q_seqlen_or_cu_seqlen_aval, q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
seed_aval, seed_aval,
*, *,
attn_bias_type, attn_bias_type,
...@@ -190,6 +192,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -190,6 +192,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
""" """
Fused attention fwd abstract Fused attention fwd abstract
...@@ -227,7 +230,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -227,7 +230,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else: else:
raise ValueError(f"Unsupported {backend=}") raise ValueError(f"Unsupported {backend=}")
...@@ -266,6 +269,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -266,6 +269,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
qkv_layout, qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
is_training, is_training,
max_segments_per_seq,
) )
wkspace_aval = q_aval.update( wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -292,6 +296,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -292,6 +296,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias, bias,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed, seed,
*, *,
attn_bias_type, attn_bias_type,
...@@ -300,11 +306,22 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -300,11 +306,22 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
""" """
Fused attention fwd lowering rules Fused attention fwd lowering rules
""" """
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
...@@ -337,6 +354,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -337,6 +354,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
num_gqa_groups, num_gqa_groups,
bias_heads, bias_heads,
head_dim, head_dim,
max_segments_per_seq,
wkspace_aval.size, wkspace_aval.size,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
...@@ -360,6 +378,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -360,6 +378,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias, bias,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
q_seq_offsets,
k_seq_offsets,
seed, seed,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
...@@ -367,11 +387,64 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -367,11 +387,64 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
assert FusedAttnFwdPrimitive.inner_primitive is not None assert FusedAttnFwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen) if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
def _fix_len_take(x, condition):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=-1)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather the valid seqlen entries, i.e., those greater than 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather the valid seq_offsets entries, i.e., those greater than or equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused positions to the max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q, q,
...@@ -380,6 +453,8 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -380,6 +453,8 @@ class FusedAttnFwdPrimitive(BasePrimitive):
bias, bias,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -387,6 +462,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -387,6 +462,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
return output, softmax_aux, rng_state return output, softmax_aux, rng_state
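To make the THD preprocessing above concrete, here is a small, self-contained sketch of the same three steps (compact the valid entries, flatten the per-row offsets, replace the padding slots) using local copies of the internal helpers; the input values are made up:

import jax.numpy as jnp

def _fix_len_take(x, condition):
    # Compact valid entries to the front of the flattened array; pad the tail with -1.
    indices = jnp.nonzero(condition.flatten(), size=x.size, fill_value=x.size)[0]
    return jnp.take(x.flatten(), indices, fill_value=-1).reshape(x.shape)

def convert_to_2d(offsets, batch, max_seqlen):
    # Shift each row's offsets into the flattened (batch * max_seqlen) index space.
    return jnp.where(
        offsets >= 0,
        offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
        offsets,
    )

batch, max_seqlen = 2, 8
seqlens = jnp.asarray([[3, 5, -1, -1], [2, 4, -1, -1]])
offsets = jnp.asarray([[0, 3, -1, -1, -1], [0, 2, -1, -1, -1]])
seqlens = _fix_len_take(seqlens, seqlens > 0)        # [[3, 5, 2, 4], [-1, -1, -1, -1]]
offsets = convert_to_2d(offsets, batch, max_seqlen)  # [[0, 3, -1, -1, -1], [8, 10, -1, -1, -1]]
offsets = _fix_len_take(offsets, offsets >= 0)       # [[0, 3, 8, 10, -1], [-1, -1, -1, -1, -1]]
offsets = jnp.where(offsets < 0, batch * max_seqlen, offsets)  # pad slots -> 16

The kernel then receives generate_cu_seqlen(seqlens.flatten()) together with these flattened offsets, so every padded offset slot points to the end of the packed buffer.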
...@@ -401,6 +477,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -401,6 +477,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None assert FusedAttnFwdPrimitive.outer_primitive is not None
...@@ -416,6 +493,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -416,6 +493,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
), ),
out_bdims, out_bdims,
) )
...@@ -428,29 +506,30 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -428,29 +506,30 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del attn_bias_type, attn_mask_type, scaling_factor del attn_bias_type, attn_mask_type, scaling_factor
del dropout_probability, is_training, result_infos del dropout_probability, is_training, max_segments_per_seq, result_infos
q_spec = get_padded_spec(arg_infos[0]) q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
match qkv_layout: match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
# q_spec = (...batch, q_seqlen, head, hidden) # q_spec = (...batch, q_seqlen, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
softmax_aux_sharding = NamedSharding( softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
) )
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
# q_spec = (...batch, q_seqlen, head, hidden) # q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding( softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]) mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])
) )
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
# q_spec = (...batch, q_seqlen, head, hidden) # q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
...@@ -470,6 +549,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -470,6 +549,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -489,6 +569,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -489,6 +569,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
...@@ -503,7 +584,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -503,7 +584,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name = "te_fused_attn_backward" name = "te_fused_attn_backward"
multiple_results = True multiple_results = True
impl_static_args = (10, 11, 12, 13, 14, 15) impl_static_args = (12, 13, 14, 15, 16, 17, 18)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -517,8 +598,10 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -517,8 +598,10 @@ class FusedAttnBwdPrimitive(BasePrimitive):
rng_state_aval, rng_state_aval,
output_aval, output_aval,
doutput_aval, doutput_aval,
q_cu_seqlen_aval, q_seqlen_or_cu_seqlen_aval,
kv_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
*, *,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
...@@ -526,6 +609,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -526,6 +609,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
""" """
Fused attention bwd abstract Fused attention bwd abstract
...@@ -538,7 +622,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -538,7 +622,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
...@@ -567,6 +651,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -567,6 +651,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
qkv_layout, qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(q_aval.dtype),
is_training, is_training,
max_segments_per_seq,
) )
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
...@@ -600,6 +685,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -600,6 +685,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput, doutput,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
*, *,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
...@@ -607,6 +694,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -607,6 +694,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
""" """
Fused attention bwd lowering rules Fused attention bwd lowering rules
...@@ -622,6 +710,8 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -622,6 +710,8 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput, doutput,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
] ]
operand_shapes = map(lambda x: x.type.shape, operands) operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [ out_types = [
...@@ -656,6 +746,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -656,6 +746,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
num_gqa_groups, num_gqa_groups,
bias_heads, bias_heads,
head_dim, head_dim,
max_segments_per_seq,
wkspace_aval.size, wkspace_aval.size,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
...@@ -683,17 +774,73 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -683,17 +774,73 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput, doutput,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type, attn_bias_type,
attn_mask_type, attn_mask_type,
qkv_layout, qkv_layout,
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
assert FusedAttnBwdPrimitive.inner_primitive is not None assert FusedAttnBwdPrimitive.inner_primitive is not None
q_cu_seqlen = generate_cu_seqlen(q_seqlen) if nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD:
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
def _fix_len_take(x, condition):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
# TODO(rewang): try indices_are_sorted
y = jnp.take(x, indices, fill_value=-1)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
match qkv_layout:
case NVTE_QKV_Layout.NVTE_T3HD:
kv_max_seqlen = q_max_seqlen = q.shape[-4]
kv_batch = q_batch = reduce(operator.mul, q.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_T2HD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-4]
kv_batch = reduce(operator.mul, k.shape[:-4])
case NVTE_QKV_Layout.NVTE_THD_THD_THD:
q_max_seqlen = q.shape[-3]
q_batch = reduce(operator.mul, q.shape[:-3])
kv_max_seqlen = k.shape[-3]
kv_batch = reduce(operator.mul, k.shape[:-3])
# Gather the valid seqlen entries, i.e., those greater than 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather the valid seq_offsets entries, i.e., those greater than or equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0)
k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0)
# Set the unused positions to the max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets)
k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q, q,
...@@ -706,12 +853,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -706,12 +853,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput, doutput,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
return dq, dk, dv, dbias return dq, dk, dv, dbias
...@@ -726,6 +876,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -726,6 +876,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
): ):
check_valid_batch_dims(batch_dims) check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None assert FusedAttnBwdPrimitive.outer_primitive is not None
...@@ -741,6 +892,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -741,6 +892,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
), ),
out_bdims, out_bdims,
) )
...@@ -753,11 +905,12 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -753,11 +905,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, max_segments_per_seq
del dropout_probability, is_training, result_infos del dropout_probability, is_training, result_infos
q_spec = get_padded_spec(arg_infos[0]) q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1]) k_spec = get_padded_spec(arg_infos[1])
...@@ -777,6 +930,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -777,6 +930,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
scaling_factor, scaling_factor,
dropout_probability, dropout_probability,
is_training, is_training,
max_segments_per_seq,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -794,7 +948,18 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -794,7 +948,18 @@ class FusedAttnBwdPrimitive(BasePrimitive):
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl( def sharded_impl(
q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
): ):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q, q,
...@@ -807,12 +972,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -807,12 +972,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
doutput, doutput,
q_cu_seqlen, q_cu_seqlen,
kv_cu_seqlen, kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
global_dbias = local_dbias global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
...@@ -825,245 +993,182 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -825,245 +993,182 @@ class FusedAttnBwdPrimitive(BasePrimitive):
register_primitive(FusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
def fused_attn_fwd_qkvpacked( def fused_attn_fwd(
qkv: jnp.ndarray, qkv: Tuple[jnp.ndarray, ...],
bias: jnp.ndarray, bias: Optional[jnp.ndarray],
seqlen: jnp.ndarray,
seed: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE self fused attention fwd
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
"""
checker = _FusedAttnRNGStateChecker()
seed = checker.check_seed(seed, dropout_probability, is_training)
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
_not_used = jnp.zeros(0, qkv.dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind(
qkv,
_not_used,
_not_used,
bias,
seqlen,
seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
def fused_attn_bwd_qkvpacked(
qkv: jnp.ndarray,
bias: jnp.ndarray,
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE self fused attention bwd
Return the gradients of self fused attention with packed qkv input
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv.dtype)
dummy_input = jnp.zeros(0, dtype=qkv.dtype)
dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
qkv,
dummy_input,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
seqlen,
seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return dqkv, dbias
def fused_attn_fwd_kvpacked(
q: jnp.ndarray,
kv: jnp.ndarray,
bias: jnp.ndarray,
q_seqlen: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
seed: jnp.ndarray, q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
seed: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
): max_segments_per_seq: int,
) -> jnp.ndarray:
""" """
Wrapper for TE fused attention fwd with kvpacked inputs Perform the forward pass with cuDNN fused attention implementations.
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the key and value, with shape [batch + 1,].
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int): The maximum number of segments packed into each sequence;
only used with the THD (ragged) QKV formats.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
""" """
checker = _FusedAttnRNGStateChecker() seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
seed = checker.check_seed(seed, dropout_probability, is_training)
assert (q_seq_offsets is None) == (
kv_seq_offsets is None
), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
# Placeholders for optional tensors, since custom calls don't support None
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
match qkv_layout:
case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=q.dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind( return FusedAttnFwdPrimitive.outer_primitive.bind(
q, *qkv_for_primitive,
kv,
jnp.zeros(0, q.dtype),
bias, bias,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
seed, seed,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
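Taken together, the new wrapper can be exercised roughly as follows. This is a hedged sketch only: the import paths, the enum choices, and the THD padding conventions (-1 for unused slots, each offsets row listing segment starts followed by the end-of-data offset) are assumptions read off the code above, and actually running it requires a cuDNN-capable GPU build of Transformer Engine:

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import fused_attn_fwd, fused_attn_bwd  # assumed import path
from transformer_engine_jax import (  # assumed pybind module exposing the NVTE enums
    NVTE_Bias_Type, NVTE_Mask_Type, NVTE_QKV_Layout,
)

b, s, h, d = 2, 128, 8, 64
max_segments_per_seq = 4
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)
# Two packed segments per row; -1 marks unused slots.
q_seqlen = kv_seqlen = jnp.asarray([[64, 64, -1, -1], [100, 28, -1, -1]])
q_off = kv_off = jnp.asarray([[0, 64, 128, -1, -1], [0, 100, 128, -1, -1]])

out = fused_attn_fwd(
    (q, k, v),            # NVTE_THD_THD_THD takes separate q, k, v
    None,                 # no bias
    q_seqlen, kv_seqlen,
    q_off, kv_off,
    None,                 # seed; unused since dropout is 0.0
    attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
    attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
    qkv_layout=NVTE_QKV_Layout.NVTE_THD_THD_THD,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,
    is_training=True,
    max_segments_per_seq=max_segments_per_seq,
)

Note that with NVTE_NO_BIAS the wrapper substitutes a zero-size dummy for bias, and for the non-THD layouts the two offset arguments are replaced by the same dummy, so BSHD callers can simply pass None for both.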
def fused_attn_bwd_kvpacked( def fused_attn_bwd(
q: jnp.ndarray, qkv: Tuple[jnp.ndarray, ...],
kv: jnp.ndarray, bias: Optional[jnp.ndarray],
bias: jnp.ndarray,
softmax_aux: jnp.ndarray, softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray, rng_state: jnp.ndarray,
output: jnp.ndarray, output: jnp.ndarray,
doutput: jnp.ndarray, doutput: jnp.ndarray,
q_seqlen: jnp.ndarray, q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
q_seq_offsets: Optional[jnp.ndarray],
kv_seq_offsets: Optional[jnp.ndarray],
attn_bias_type: NVTE_Bias_Type, attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type, attn_mask_type: NVTE_Mask_Type,
qkv_layout: NVTE_QKV_Layout,
scaling_factor: float, scaling_factor: float,
dropout_probability: float, dropout_probability: float,
is_training: bool, is_training: bool,
max_segments_per_seq: int,
): ):
""" """
Wrapper for TE fused attention bwd with kvpacked inputs Perform the backward pass of the cuDNN fused attention implementations.
Return the gradients of fused attention with packed kv input
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
used in the forward pass. It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
doutput (jnp.ndarray): The gradient with respect to the output.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the query, with shape [batch + 1,].
kv_seq_offsets (Optional[jnp.ndarray]):
The offsets in the sequence dim for the key and value, with shape [batch + 1,].
attn_bias_type (NVTE_Bias_Type): Type of attention bias.
attn_mask_type (NVTE_Mask_Type): Type of attention mask.
qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int): The maximum number of segments packed into each sequence;
only used with the THD (ragged) QKV formats.
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
same format as the input `qkv`.
- The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
""" """
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
dummy_input = jnp.zeros(0, q.dtype)
dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
q,
kv,
dummy_input,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
return dq, dkv, dbias
assert (q_seq_offsets is None) == (
def fused_attn_fwd( kv_seq_offsets is None
q: jnp.ndarray, ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values."
k: jnp.ndarray, is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD
v: jnp.ndarray,
bias: jnp.ndarray, # Placeholders for optional tensors, since custom calls don't support None
q_seqlen: jnp.ndarray, _not_used = jnp.zeros(0, dtype=qkv[0].dtype)
kv_seqlen: jnp.ndarray,
seed: jnp.ndarray, match qkv_layout:
attn_bias_type: NVTE_Bias_Type, case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD:
attn_mask_type: NVTE_Mask_Type, assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
scaling_factor: float, qkv_for_primitive = [*qkv, _not_used, _not_used]
dropout_probability: float, case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD:
is_training: bool, assert (
): len(qkv) == 2
""" ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
Wrapper for TE fused attention fwd, where query, key, value are separate tensors qkv_for_primitive = [*qkv, _not_used]
Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD:
""" assert (
checker = _FusedAttnRNGStateChecker() len(qkv) == 3
seed = checker.check_seed(seed, dropout_probability, is_training) ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=q.dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
return FusedAttnFwdPrimitive.outer_primitive.bind( *qkv_grads, bias_grad = FusedAttnBwdPrimitive.outer_primitive.bind(
q, *qkv_for_primitive,
k,
v,
bias,
q_seqlen,
kv_seqlen,
seed,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
)
def fused_attn_bwd(
q: jnp.ndarray,
k: jnp.ndarray,
v: jnp.ndarray,
bias: jnp.ndarray,
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
q_seqlen: jnp.ndarray,
kv_seqlen: jnp.ndarray,
attn_bias_type: NVTE_Bias_Type,
attn_mask_type: NVTE_Mask_Type,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
):
"""
Wrapper for TE fused attention bwd
Return the gradients of fused attention with separate query, key, value tensors
"""
if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=q.dtype)
return FusedAttnBwdPrimitive.outer_primitive.bind(
q,
k,
v,
bias, bias,
softmax_aux, softmax_aux,
rng_state, rng_state,
...@@ -1071,10 +1176,14 @@ def fused_attn_bwd( ...@@ -1071,10 +1176,14 @@ def fused_attn_bwd(
doutput, doutput,
q_seqlen, q_seqlen,
kv_seqlen, kv_seqlen,
q_seq_offsets if is_ragged else _not_used,
kv_seq_offsets if is_ragged else _not_used,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, qkv_layout=qkv_layout,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_probability, dropout_probability=dropout_probability,
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
) )
return tuple(qkv_grads[: len(qkv)]), bias_grad
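The backward wrapper mirrors the forward call. Continuing the sketch above (again an assumption-laden illustration: the residuals softmax_aux, rng_state, and output are zero-filled stand-ins here, whereas in real use they are produced by the forward primitive):

# Aux shape follows the arbitrary-seqlen backend rule above:
# (*batch, heads, q_max_seqlen, max_segments_per_seq) in float32.
softmax_aux = jnp.zeros((b, h, s, max_segments_per_seq), jnp.float32)
rng_state = jnp.zeros((2,), jnp.int64)  # a true int64 needs jax_enable_x64
output = jnp.zeros((b, s, h, d), jnp.bfloat16)
doutput = jnp.ones_like(output)

(dq, dk, dv), dbias = fused_attn_bwd(
    (q, k, v), None, softmax_aux, rng_state, output, doutput,
    q_seqlen, kv_seqlen, q_off, kv_off,
    attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
    attn_mask_type=NVTE_Mask_Type.NVTE_PADDING_MASK,
    qkv_layout=NVTE_QKV_Layout.NVTE_THD_THD_THD,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,
    is_training=True,
    max_segments_per_seq=max_segments_per_seq,
)
# dq/dk/dv match the shapes of q/k/v; dbias is a zero-size dummy under NVTE_NO_BIAS.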
...@@ -137,6 +137,7 @@ struct CustomCallFusedAttnDescriptor { ...@@ -137,6 +137,7 @@ struct CustomCallFusedAttnDescriptor {
size_t num_gqa_groups; size_t num_gqa_groups;
size_t bias_heads; size_t bias_heads;
size_t head_dim; size_t head_dim;
size_t max_segments_per_seq;
size_t wkspace_size; size_t wkspace_size;
float scaling_factor; float scaling_factor;
float dropout_probability; float dropout_probability;
...@@ -151,9 +152,9 @@ struct CustomCallFusedAttnDescriptor { ...@@ -151,9 +152,9 @@ struct CustomCallFusedAttnDescriptor {
pybind11::bytes PackCustomCallFusedAttnDescriptor( pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
bool is_training); NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);
// Transpose // Transpose
...@@ -249,7 +250,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -249,7 +250,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...@@ -257,7 +259,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -257,7 +259,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -104,7 +104,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -104,7 +104,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
// For qkv_packed // For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
...@@ -128,128 +129,50 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -128,128 +129,50 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
}
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
}
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads,
size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype,
bool is_training) {
auto output_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
auto bias_shape = std::vector<size_t>{1, attn_heads, q_max_seqlen, kv_max_seqlen};
auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
// F16 doesn't use s_tensor
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto dummy_ragged_offset_tensor = auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32); // It is a WAR to pre-create all possible cuDNN graphs at JIT compile time
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
assert(q_max_seqlen == kv_max_seqlen); for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; // the last one is the largest, and it determines the returned workspace size
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto q_cu_seqlens_tensor =
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), auto kv_cu_seqlens_tensor =
s_tensor.data(), // not used for F16 TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
s_tensor.data(), // not used for F16 auto ragged_offset_tensor =
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must be equal to kv_max_seqlen");
bias_type, mask_type, query_workspace_tensor.data(), nullptr); nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { o_tensor.data(), &aux_output_tensors,
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim}; q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); scaling_factor, dropout_probability, qkv_layout, bias_type,
auto kv_shape = std::vector<size_t>{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; mask_type, query_workspace_tensor.data(), nullptr);
auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked(
nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
s_tensor.data(), // not used for F16 ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
s_tensor.data(), // not used for F16 q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), bias_type, mask_type, query_workspace_tensor.data(), nullptr);
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
nullptr); q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { ragged_offset_tensor.data(), ragged_offset_tensor.data(),
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim}; dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); query_workspace_tensor.data(), nullptr);
auto k_shape = std::vector<size_t>{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; } else {
auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); NVTE_ERROR("Unsupported QKVLayout.");
auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); }
auto v_shape = k_shape;
auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
} }
auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape());
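Restated in plain form, the new loop queries the workspace once per possible runtime segment count, so that every cuDNN graph already exists when the kernel later runs with a data-dependent count. A conceptual Python sketch, with query_workspace_size as a hypothetical stand-in for the per-count nvte_fused_attn_fwd query:

def prebuild_graphs_and_size_workspace(query_workspace_size, input_batch,
                                       max_segments_per_seq, is_ragged):
    # WAR: instantiate the cuDNN graph for every possible segment count at
    # JIT compile time. The last count is the largest; the source relies on
    # it producing the largest (returned) workspace.
    max_num_segments = input_batch * max_segments_per_seq if is_ragged else input_batch
    workspace_size = 0
    for num_segments in range(input_batch, max_num_segments + 1):
        workspace_size = query_workspace_size(num_segments)
    return workspace_size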
...@@ -260,18 +183,23 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -260,18 +183,23 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
const CustomCallFusedAttnDescriptor &descriptor = const CustomCallFusedAttnDescriptor &descriptor =
*UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len); *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
auto qkv_layout = descriptor.qkv_layout;
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
/* Input buffers from XLA */ /* Input buffers from XLA */
/* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
void *bias = buffers[3]; void *bias = buffers[3];
void *q_cu_seqlens = buffers[4]; void *q_cu_seqlens = buffers[4];
void *kv_cu_seqlens = buffers[5]; void *kv_cu_seqlens = buffers[5];
void *seed = buffers[6]; void *q_seq_offsets = is_ragged ? buffers[6] : nullptr;
void *k_seq_offsets = is_ragged ? buffers[7] : nullptr;
void *seed = buffers[8];
/* Output buffer from XLA */ /* Output buffer from XLA */
void *output = buffers[7]; void *output = buffers[9];
void *softmax_aux = buffers[8]; void *softmax_aux = buffers[10];
void *rng_state = buffers[9]; void *rng_state = buffers[11];
void *workspace = buffers[10]; void *workspace = buffers[12];
/* Descriptor */ /* Descriptor */
auto input_batch = descriptor.input_batch; auto input_batch = descriptor.input_batch;
...@@ -286,8 +214,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -286,8 +214,9 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto dropout_probability = descriptor.dropout_probability; auto dropout_probability = descriptor.dropout_probability;
auto bias_type = descriptor.bias_type; auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type; auto mask_type = descriptor.mask_type;
auto qkv_layout = descriptor.qkv_layout;
auto dtype = descriptor.dtype; auto dtype = descriptor.dtype;
auto is_training = descriptor.is_training;
auto max_segments_per_seq = descriptor.max_segments_per_seq;
/* Input tensors */ /* Input tensors */
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -296,14 +225,33 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -296,14 +225,33 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
size_t num_segments = input_batch; // For non-THD formats, num_segments == input_batch
if (is_ragged) {
// The workspace buffer can be reused here since it is not used by the cuDNN graph at the same time
size_t runtime_num_segments_q =
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
size_t runtime_num_segments_kv =
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
num_segments = runtime_num_segments_q;
cudaMemsetAsync(output, 0,
input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream);
}
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto q_seq_offsets_tensor =
TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
auto k_seq_offsets_tensor =
TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
/* Output tensors */ /* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16 auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto o_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto o_tensor = TensorWrapper(output, o_shape, dtype); auto o_tensor = TensorWrapper(output, o_shape, dtype);
auto q_cu_seqlens_tensor =
TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
auto kv_cu_seqlens_tensor =
TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Prepare RNG state */ /* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
...@@ -323,19 +271,18 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -323,19 +271,18 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size}, auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
descriptor.wkspace_dtype); descriptor.wkspace_dtype);
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
/* Call the underlying NVTE API */ /* Call the underlying NVTE API */
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv = buffers[0]; auto qkv = buffers[0];
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor, rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype); auto q_tensor = TensorWrapper(q, q_shape, dtype);
...@@ -345,11 +292,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -345,11 +292,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, bias_type, mask_type, workspace_tensor.data(), stream);
workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(q, q_shape, dtype); auto q_tensor = TensorWrapper(q, q_shape, dtype);
...@@ -359,13 +305,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -359,13 +305,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto v = buffers[2]; auto v = buffers[2];
auto v_shape = k_shape; auto v_shape = k_shape;
auto v_tensor = TensorWrapper(v, v_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), nvte_fused_attn_fwd(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
bias_type, mask_type, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
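The runtime check earlier in this function has a simple reading: count the segments actually present, and require the q and kv counts to match and to stay within input_batch * max_segments_per_seq. A conceptual JAX sketch of that semantics (an assumption about what GetRuntimeNumSegments computes from the sequence lengths, not the CUDA kernel itself):

import jax.numpy as jnp

def runtime_num_segments(seqlens):
    # Valid segments have a positive length; padded slots (-1) do not count.
    return int(jnp.sum(seqlens > 0))

q_seqlen = jnp.asarray([[64, 64, -1, -1], [100, 28, -1, -1]])
kv_seqlen = q_seqlen
input_batch, max_segments_per_seq = 2, 4
n_q, n_kv = runtime_num_segments(q_seqlen), runtime_num_segments(kv_seqlen)
assert n_q == n_kv
assert n_q <= input_batch * max_segments_per_seq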
@@ -377,46 +322,89 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
     float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) {
-  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
-  auto v_shape = k_shape;
-  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
-  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
+    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
+    size_t max_segments_per_seq) {
+  // For qkv_packed
+  auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
+  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
+  auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
+  // For kv_packed
+  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
   auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
+  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
+  auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
+  auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
+  auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype);
+  // For separate q, k, v
+  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
   auto k_tensor = TensorWrapper(nullptr, k_shape, dtype);
+  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
+  auto v_shape = k_shape;
   auto v_tensor = TensorWrapper(nullptr, v_shape, dtype);
+  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
+  auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
   auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype);
   auto output_tensor = TensorWrapper(nullptr, output_shape, dtype);
   // F16 doesn't use this tensor
   auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
-  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
-  auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype);
-  auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype);
+  auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
   auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
-  auto q_cu_seqlens_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
-  auto kv_cu_seqlens_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   NVTETensorPack aux_input_tensors;
   nvte_tensor_pack_create(&aux_input_tensors);
   TensorWrapper query_workspace_tensor;
-  auto dummy_ragged_offset_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
-  nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                      doutput_tensor.data(),
-                      s_tensor.data(),  // not used for F16
-                      s_tensor.data(),  // not used for F16
-                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                      dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-                      q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
-                      bias_type, mask_type, query_workspace_tensor.data(), nullptr);
+  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
+  // This is a workaround (WAR): pre-create every possible cuDNN graph at JIT compile time,
+  // since the runtime number of segments is unknown until execution.
+  size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch;
+  for (auto num_segments = input_batch; num_segments <= max_num_segments; ++num_segments) {
+    // The last iteration is the largest and determines the returned workspace size.
+    auto q_cu_seqlens_tensor =
+        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+    auto kv_cu_seqlens_tensor =
+        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+    auto dummy_ragged_offset_tensor =
+        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+    if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
+      nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+                                    s_tensor.data(),  // not used for F16
+                                    s_tensor.data(),  // not used for F16
+                                    &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
+                                    q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+                                    q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
+                                    bias_type, mask_type, query_workspace_tensor.data(), nullptr);
+    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
+      nvte_fused_attn_bwd_kvpacked(
+          q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
+          s_tensor.data(),  // not used for F16
+          s_tensor.data(),  // not used for F16
+          &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
+          q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+          dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
+          kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+          query_workspace_tensor.data(), nullptr);
+    } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
+      nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+                          doutput_tensor.data(),
+                          s_tensor.data(),  // not used for F16
+                          s_tensor.data(),  // not used for F16
+                          &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+                          dbias_tensor.data(), q_cu_seqlens_tensor.data(),
+                          kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+                          dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
+                          scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+                          query_workspace_tensor.data(), nullptr);
+    } else {
+      NVTE_ERROR("Unsupported qkv_layout.");
+    }
+  }
   auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
   return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
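Since the true number of ragged segments is only known at run time, the workspace query above iterates over every possible segment count so that each corresponding cuDNN graph is already built at JIT compile time; the final (largest) iteration determines the returned workspace size. Roughly, in Python terms (illustrative sketch only, not part of the API):

```python
# Illustrative sketch of the candidate segment counts probed above.
def candidate_segment_counts(input_batch: int, max_segments_per_seq: int, is_ragged: bool):
    max_num_segments = input_batch * max_segments_per_seq if is_ragged else input_batch
    # One workspace query per count; the last entry sizes the returned workspace.
    return range(input_batch, max_num_segments + 1)
```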
@@ -426,6 +414,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   const CustomCallFusedAttnDescriptor &descriptor =
       *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);
+  auto qkv_layout = descriptor.qkv_layout;
+  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;
   /* Input buffers from XLA */
   /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */
   void *bias = buffers[3];
@@ -435,11 +426,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   void *doutput = buffers[7];
   void *q_cu_seqlens = buffers[8];
   void *kv_cu_seqlens = buffers[9];
+  void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
+  void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;
   /* Output buffer from XLA */
-  /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */
-  void *dbias = buffers[13];
-  void *workspace = buffers[14];
+  /* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */
+  void *dbias = buffers[15];
+  void *workspace = buffers[16];
   /* Descriptor */
   auto input_batch = descriptor.input_batch;
@@ -454,8 +447,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto dropout_probability = descriptor.dropout_probability;
   auto bias_type = descriptor.bias_type;
   auto mask_type = descriptor.mask_type;
-  auto qkv_layout = descriptor.qkv_layout;
   auto dtype = descriptor.dtype;
+  auto max_segments_per_seq = descriptor.max_segments_per_seq;
   /* Input tensors */
   auto output_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -463,13 +456,30 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto output_tensor = TensorWrapper(output, output_shape, dtype);
   auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype);
+  size_t num_segments = input_batch;  // for non-THD formats, input_batch == num_segments
+  if (is_ragged) {
+    // The workspace can be reused here because it is not consumed by the cuDNN graph
+    // at the same time.
+    size_t runtime_num_segments_q =
+        GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);
+    size_t runtime_num_segments_kv =
+        GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);
+    NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);
+    NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);
+    num_segments = runtime_num_segments_q;
+  }
+  auto q_cu_seqlens_tensor =
+      TensorWrapper(q_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+  auto kv_cu_seqlens_tensor =
+      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+  auto q_seq_offsets_tensor =
+      TensorWrapper(q_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
+  auto k_seq_offsets_tensor =
+      TensorWrapper(k_seq_offsets, std::vector<size_t>{num_segments + 1}, DType::kInt32);
   /* Output tensors */
   auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);  // not used in F16
   auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
-  auto q_cu_seqlens_tensor =
-      TensorWrapper(q_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
-  auto kv_cu_seqlens_tensor =
-      TensorWrapper(kv_cu_seqlens, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   /* Auxiliary tensors (propagated from the forward pass) */
   NVTETensorPack aux_input_tensors;
@@ -486,42 +496,54 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto wkspace_dtype = descriptor.wkspace_dtype;
   auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);
-  auto dummy_ragged_offset_tensor =
-      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   /* Call the underlying NVTE API */
-  if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
+  auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
     auto qkv = buffers[0];
     auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
     auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
-    auto dqkv = buffers[10];
+    auto dqkv = buffers[12];
     auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
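+    // Zero-init is needed for ragged gradients: the THD kernels are expected to write only
+    // the valid tokens selected by the offsets, leaving padding regions untouched.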
+    if (is_ragged) {
+      size_t dqkv_size =
+          std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
+      cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
+    }
     nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
                                   s_tensor.data(),  // not used for F16
                                   s_tensor.data(),  // not used for F16
                                   &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
-                                  q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+                                  q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
                                   q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
                                   bias_type, mask_type, workspace_tensor.data(), stream);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
+  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
     auto kv = buffers[1];
     auto kv_shape = std::vector<size_t>{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim};
     auto kv_tensor = TensorWrapper(kv, kv_shape, dtype);
-    auto dq = buffers[10];
+    auto dq = buffers[12];
     auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
-    auto dkv = buffers[11];
+    auto dkv = buffers[13];
     auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
+    if (is_ragged) {
+      size_t dq_size =
+          std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
+      size_t dkv_size =
+          std::accumulate(kv_shape.cbegin(), kv_shape.cend(), 1, std::multiplies<size_t>());
+      cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
+      cudaMemsetAsync(dkv, 0, dkv_size * typeToSize(dtype), stream);
+    }
     nvte_fused_attn_bwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
         s_tensor.data(),  // not used for F16
         s_tensor.data(),  // not used for F16
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-        dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
+        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
-  } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
+  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
     auto q_tensor = TensorWrapper(q, q_shape, dtype);
@@ -531,21 +553,31 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto v = buffers[2];
     auto v_shape = k_shape;
     auto v_tensor = TensorWrapper(v, v_shape, dtype);
-    auto dq = buffers[10];
+    auto dq = buffers[12];
     auto dq_tensor = TensorWrapper(dq, q_shape, dtype);
-    auto dk = buffers[11];
+    auto dk = buffers[13];
     auto dk_tensor = TensorWrapper(dk, k_shape, dtype);
-    auto dv = buffers[12];
+    auto dv = buffers[14];
     auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
+    if (is_ragged) {
+      size_t dq_size =
+          std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
+      size_t dk_size =
+          std::accumulate(k_shape.cbegin(), k_shape.cend(), 1, std::multiplies<size_t>());
+      size_t dv_size = dk_size;
+      cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
+      cudaMemsetAsync(dk, 0, dk_size * typeToSize(dtype), stream);
+      cudaMemsetAsync(dv, 0, dv_size * typeToSize(dtype), stream);
+    }
     nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                         doutput_tensor.data(),
                         s_tensor.data(),  // not used for F16
                         s_tensor.data(),  // not used for F16
                         &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
                         dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-                        kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
-                        dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+                        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
+                        k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+                        dropout_probability, qkv_layout, bias_type, mask_type,
                         workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
...
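For orientation, the backward custom call's XLA buffer table after this change, reconstructed from the reads in this file (slots 4-6 fall outside the visible hunks and are omitted; treat this as a summary, not a spec):

```python
# Reconstructed from FusedAttnBackward above; slots 4-6 are not visible in this diff.
BWD_BUFFER_SLOTS = {
    0: "q", 1: "k", 2: "v",                  # parsed per layout group
    3: "bias",
    7: "doutput",
    8: "q_cu_seqlens", 9: "kv_cu_seqlens",
    10: "q_seq_offsets",                     # only consumed when is_ragged
    11: "k_seq_offsets",                     # only consumed when is_ragged
    12: "dq (or dqkv)", 13: "dk (or dkv)", 14: "dv",
    15: "dbias",
    16: "workspace",
}
```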
@@ -66,13 +66,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin
 pybind11::bytes PackCustomCallFusedAttnDescriptor(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
-    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
-    bool is_training) {
+    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
+    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
   return PackOpaque(CustomCallFusedAttnDescriptor{
       input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
-      head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout,
-      dtype, wkspace_dtype, is_training});
+      head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
+      mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
 }
 }  // namespace jax
...
@@ -67,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes);
   m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
   m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
+  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
   pybind11::enum_<DType>(m, "DType", pybind11::module_local())
       .value("kByte", DType::kByte)
@@ -92,7 +93,15 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
   pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())
       .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
       .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
-      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD);
+      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
+      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
+      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
+      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);
+  pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local())
+      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
+      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
+      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local()) pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU) .value("GELU", NVTE_Activation_Type::GELU)
......
@@ -45,5 +45,29 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
+__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid >= len) return;
+  if (cu_seqlen[tid] > 0) {
+    // atomicAdd only supports 32-bit integer types
+    atomicAdd(out, 1);
+  }
+}
+
+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
+  // requires a workspace of at least 4 bytes
+  uint32_t *dout = static_cast<uint32_t *>(workspace);
+  uint32_t hout{};
+  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  constexpr int threads = 128;
+  const int blocks = (len - 1) / threads + 1;
+  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
+                                                                  len, dout);
+  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+  return hout;
+}
 }  // namespace jax
 }  // namespace transformer_engine
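`GetRuntimeNumSegments` counts the positive entries of the cumulative-sequence-length buffer, which matches the number of real segments as long as unused tail entries are zero. A NumPy rendering of the same reduction (illustrative; it mirrors the kernel, not a public API):

```python
import numpy as np

def runtime_num_segments(cu_seqlen: np.ndarray) -> int:
    # e.g. [0, 3, 7, 12, 0, 0] encodes three segments; the zero-padded
    # tail (and the leading 0) contribute nothing to the count.
    return int((cu_seqlen > 0).sum())

assert runtime_num_segments(np.array([0, 3, 7, 12, 0, 0], dtype=np.int32)) == 3
```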
@@ -28,6 +28,8 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
                            size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                            cudaStream_t stream);
+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
+
 class cudaDevicePropertiesManager {
  public:
   static cudaDevicePropertiesManager &Instance() {
...
@@ -26,7 +26,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
 from .module import LayerNorm, Softmax
 from ..attention import AttnBiasType, AttnMaskType, QKVLayout
 from ..attention import is_fused_attn_kernel_available, canonicalize_attn_mask_type
-from ..attention import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
+from ..attention import fused_attn
 from ..softmax import SoftmaxType
 from ..sharding import num_of_devices
 from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
@@ -268,6 +268,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
         scale_factor = self.scale_factor
         del self.scale_factor
+        # TODO(rewang): integrate THD format
         if self.qkv_layout == QKVLayout.BS3HD:
             """qkvpacked format, treat
             query: qkvpacked tensor, shape = [..., 3, h, d]
@@ -277,13 +278,14 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
             qkv_packed = query
             if self.transpose_batch_sequence:
                 qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
-            x = fused_attn_qkvpacked(
-                qkv_packed,
+            x = fused_attn(
+                (qkv_packed,),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
@@ -298,14 +300,14 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
             if self.transpose_batch_sequence:
                 query = query.transpose([1, 0, 2, 3])
                 kv_packed = kv_packed.transpose([1, 0, 2, 3, 4])
-            x = fused_attn_kvpacked(
-                query,
-                kv_packed,
+            x = fused_attn(
+                (query, kv_packed),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
@@ -316,14 +318,13 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 key = key.transpose([1, 0, 2, 3])
                 value = value.transpose([1, 0, 2, 3])
             x = fused_attn(
-                query,
-                key,
-                value,
+                (query, key, value),
                 bias,
                 mask,
                 seed,
                 attn_mask_type=self.attn_mask_type,
                 attn_bias_type=self.attn_bias_type,
+                qkv_layout=self.qkv_layout,
                 scaling_factor=scale_factor,
                 dropout_probability=self.attention_dropout,
                 is_training=not deterministic,
...
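All three packings now funnel into the single `fused_attn` entry point, with the packing expressed by the tuple arity and `qkv_layout`. A condensed usage sketch of the kv-packed path as called above (shapes and dtype are illustrative; the positional argument order follows the calls in this diff):

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, QKVLayout, fused_attn,
)

b, s, h, d = 2, 128, 16, 64
q = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)
kv = jnp.zeros((b, s, 2, h, d), dtype=jnp.bfloat16)

out = fused_attn(
    (q, kv),          # tuple arity selects the kv-packed path
    None,             # bias
    None,             # mask
    None,             # rng seed (only needed with dropout)
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.NO_MASK,
    qkv_layout=QKVLayout.BSHD_BS2HD,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,
    is_training=True,
)
```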