Unverified commit e9606077, authored by Charlene Yang, committed by GitHub

[C/PyTorch] Add THD support for cuDNN attention (#832)



* add THD support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add seq_offsets_o and use new offset calculation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* addition to previous commit; fix unit test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add None for offset_o gradient
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: test padding between sequences
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: fix tests for padding between sequences
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix tests for sbhd/bshd layouts; clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend and add tests for max_seqlen_q=1 and d=256 for inference
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test sbhd/bshd layouts for sq1, d256 inference case
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace wording from accumulative to cumulative
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add offset tensors to custom fp8 mha tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add version control for cuDNN
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add sm>=90 constraint for thd support
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix cuDNN support for sq=1, d=256
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint and minor tweak for fp8 tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* modify cudnn version and restrict MQA/GQA support for THD
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add notes for seq offset tensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass jax build
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add dummy tensor to pass paddle build
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
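A note for readers on the layout itself: THD packs variable-length sequences back to back as [total_tokens, heads, head_dim] with no padding between them, and the new seq_offsets_q/k/v/o tensors carry one ragged offset per batch entry into the corresponding buffer. Below is a minimal sketch of how such offsets can be derived from per-sequence lengths, assuming a [batch + 1] int32 table as the diffs below use; make_seq_offsets and elems_per_token are illustrative names, not identifiers from this commit.

#include <cstdint>
#include <vector>

// Hypothetical helper: build a [batch + 1] ragged-offset table from sequence
// lengths. For a THD tensor each token occupies h * d elements, so sequence i
// starts at offsets[i] = sum(seqlens[0..i)) * h * d; with elems_per_token set
// to 3 * h * d the same table would address a qkv-packed (t3hd) buffer.
std::vector<int32_t> make_seq_offsets(const std::vector<int32_t> &seqlens,
                                      int32_t elems_per_token) {
  std::vector<int32_t> offsets(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    offsets[i + 1] = offsets[i] + seqlens[i] * elems_per_token;
  }
  return offsets;
}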
parent e8a17d1e
-Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348
+Subproject commit b740542818f36857acf7f9853f749bbad4118c65
@@ -37,7 +37,7 @@ from transformer_engine.jax.softmax import SoftmaxType
 is_fp8_supported, reason = is_fp8_available()
-DATA_SHAPE = [(128, 32, 512), (512, 32, 512)]
+DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]    # (B, S, H)
 DTYPE = [jnp.float32, jnp.bfloat16]
 ENABLE_FP8 = [False, True]
 FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -736,7 +736,7 @@ class TestDotProductAttn(TestLayer):
         q_key, k_key, v_key = jax.random.split(key, 3)
         b, s, *_ = shape
         if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
-            b, s = s, b
+            shape = (shape[1], shape[0]) + shape[2:]
         mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
         return [
             *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
@@ -786,6 +786,7 @@ class MultiHeadAttnAttr:
     ZERO_CEN = 'zero_centered_gamma'
     NUM_ATTN_HEADS = 'num_attention_heads'
     NUM_GQA_GROUPS = 'num_gqa_groups'
+    TRANSPOSE_BS = 'transpose_batch_sequence'
     ENABLE_ROPE = 'enable_rotary_pos_emb'
     ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
     LORA_SCOPE = 'low_rank_adaptation_scope'
@@ -795,42 +796,48 @@ class MultiHeadAttnAttr:
         ZERO_CEN: False,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding',
+        TRANSPOSE_BS: True,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding',
+        TRANSPOSE_BS: False,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding',
+        TRANSPOSE_BS: True,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: False,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: False,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: True,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: False,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
@@ -839,7 +846,8 @@ class MultiHeadAttnAttr:
         ROPE_GROUP_METHOD: 'consecutive',
         NUM_ATTN_HEADS: 8,
         NUM_GQA_GROUPS: 4,
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: True,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
@@ -848,7 +856,8 @@ class MultiHeadAttnAttr:
         ROPE_GROUP_METHOD: 'consecutive',
         NUM_ATTN_HEADS: 8,
         NUM_GQA_GROUPS: 4,
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: False,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
@@ -857,7 +866,8 @@ class MultiHeadAttnAttr:
         ROPE_GROUP_METHOD: 'alternate',
         NUM_ATTN_HEADS: 8,
         NUM_GQA_GROUPS: 4,
-        ATTN_MASK_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal',
+        TRANSPOSE_BS: True,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
@@ -865,7 +875,8 @@ class MultiHeadAttnAttr:
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
         ATTN_MASK_TYPE: 'padding',
-        LORA_SCOPE: 'all'
+        LORA_SCOPE: 'all',
+        TRANSPOSE_BS: False,
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
@@ -873,7 +884,8 @@ class MultiHeadAttnAttr:
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
         ATTN_MASK_TYPE: 'causal',
-        LORA_SCOPE: 'all'
+        LORA_SCOPE: 'all',
+        TRANSPOSE_BS: True,
     }]
@@ -882,7 +894,9 @@ class TestMultiHeadAttn(TestLayer):
     def input_getter(self, shape, dtype):
         key = jax.random.PRNGKey(seed=1234)
         q_key, kv_key = jax.random.split(key, 2)
-        s, b, *_ = shape
+        b, s, *_ = shape
+        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
+            shape = (shape[1], shape[0]) + shape[2:]
         mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
         return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
@@ -906,7 +920,7 @@ class TestMultiHeadAttn(TestLayer):
         rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
         low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
         fuse_qkv_params = True
-        transpose_batch_sequence = True
+        transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
         scale_attn_logits = False
         scaled_query_init = True
         float32_logits = False
@@ -962,6 +976,7 @@ class TestMultiHeadAttn(TestLayer):
     @pytest.mark.parametrize('dtype', DTYPE)
     @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
     def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+        self.attrs = attrs
         praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
         self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@@ -977,7 +992,7 @@ class TestMultiHeadAttn(TestLayer):
                                       fp8_format,
                                       rtol=1e-05,
                                       atol=1e-08):
+        self.attrs = attrs
         ds = DelayedScaling(fp8_format=fp8_format)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
@@ -1240,7 +1255,7 @@ class TestTransformer(TestLayer):
         q_key, kv_key = jax.random.split(key, 2)
         b, s, *_ = shape
         if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
-            b, s = s, b
+            shape = (shape[1], shape[0]) + shape[2:]
         mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
         return [
             *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
...
This diff is collapsed.
@@ -135,20 +135,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   }
   if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
         || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
-      && (max_seqlen_q % 64 == 0)
-      && (max_seqlen_kv % 64 == 0)
+      && ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0)
+          || (cudnn_runtime_version >= 90000))
       && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
           || (cudnn_runtime_version >= 8907))
-      && ((head_dim <= 128) && (head_dim % 8 == 0))
+      && ((head_dim <= 128 && head_dim % 8 == 0)
+          // TODO (cyang): add is_training to nvte_get_fused_attn_backend
+          // d=256 only supported for forward
+          || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000
+              && head_dim <= 256 && head_dim % 8 == 0))
       && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
          || ((cudnn_runtime_version >= 8906)
             && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
                || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
                   && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
                   && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
-                  && sm_arch_ == 90)
+                  && sm_arch_ >= 90)
                || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
-                  && sm_arch_ == 90)))
+                  && sm_arch_ >= 90)))
          || ((cudnn_runtime_version >= 90000)
             && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
                && sm_arch_ >= 80)))
@@ -163,6 +167,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
             && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
       && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
+          || (sm_arch_ >= 90 && cudnn_runtime_version >= 90100
+              && num_attn_heads == num_gqa_groups
+              && qkv_format == NVTE_QKV_Format::NVTE_THD)
          || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
     flag_arb = true;
   }
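Read together, the new conditions above amount to: cuDNN 9.0.0+ lifts the max_seqlen % 64 requirement, head_dim up to 256 (still a multiple of 8) is accepted on sm90+ with cuDNN 9.0.0+ but only for the forward pass per the TODO, and THD is admitted only on sm90+ with cuDNN 9.1.0+ and without MQA/GQA. A self-contained restatement of the THD gate follows (a reading aid with names mirroring the diff, not the library's code):

#include <cstddef>

// Reading aid only: the THD eligibility condition added in this hunk.
// cuDNN runtime versions encode as major*10000 + minor*100 + patch,
// so 90100 means cuDNN 9.1.0.
bool thd_backend_eligible(int sm_arch, int cudnn_runtime_version,
                          size_t num_attn_heads, size_t num_gqa_groups) {
  return sm_arch >= 90                      // Hopper or newer
      && cudnn_runtime_version >= 90100     // cuDNN >= 9.1.0
      && num_attn_heads == num_gqa_groups;  // MQA/GQA not yet supported for THD
}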
@@ -211,6 +218,10 @@ void nvte_fused_attn_fwd_qkvpacked(
     NVTETensor O,
     NVTETensorPack* Aux_CTX_Tensors,
     const NVTETensor cu_seqlens,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     const NVTETensor rng_state,
     size_t max_seqlen,
     bool is_training, float attn_scale, float dropout,
@@ -222,6 +233,10 @@ void nvte_fused_attn_fwd_qkvpacked(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
   const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
   const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
@@ -272,6 +287,7 @@ void nvte_fused_attn_fwd_qkvpacked(
       input_QKV, input_Bias, output_O,
       Aux_CTX_Tensors,
       input_cu_seqlens,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
       input_rng_state,
       wkspace, stream, handle);
 #else
@@ -306,6 +322,10 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTETensor dQKV,
     NVTETensor dBias,
     const NVTETensor cu_seqlens,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     size_t max_seqlen,
     float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -316,6 +336,10 @@ void nvte_fused_attn_bwd_qkvpacked(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
   const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
   const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
@@ -377,7 +401,9 @@ void nvte_fused_attn_bwd_qkvpacked(
       input_QKV, input_O, input_dO, input_Bias,
       output_S,
       output_dQKV, output_dBias,
-      input_cu_seqlens, input_rng_state,
+      input_cu_seqlens,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
+      input_rng_state,
       wkspace, stream, handle);
 #else
   const char *err_msg =
@@ -417,6 +443,10 @@ void nvte_fused_attn_fwd_kvpacked(
     NVTETensorPack* Aux_CTX_Tensors,
     const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     const NVTETensor rng_state,
     size_t max_seqlen_q, size_t max_seqlen_kv,
     bool is_training, float attn_scale, float dropout,
@@ -428,6 +458,10 @@ void nvte_fused_attn_fwd_kvpacked(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
@@ -482,6 +516,7 @@ void nvte_fused_attn_fwd_kvpacked(
       input_Q, input_KV, input_Bias, output_O,
       Aux_CTX_Tensors,
       input_cu_seqlens_q, input_cu_seqlens_kv,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
       input_rng_state,
       wkspace, stream, handle);
 #else
@@ -519,6 +554,10 @@ void nvte_fused_attn_bwd_kvpacked(
     NVTETensor dBias,
     const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -529,6 +568,10 @@ void nvte_fused_attn_bwd_kvpacked(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor*>(KV);
   const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
@@ -596,6 +639,7 @@ void nvte_fused_attn_bwd_kvpacked(
       output_S,
       output_dQ, output_dKV, output_dBias,
       input_cu_seqlens_q, input_cu_seqlens_kv,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
       input_rng_state, wkspace, stream, handle);
 #else
   const char *err_msg =
@@ -636,6 +680,10 @@ void nvte_fused_attn_fwd(
     NVTETensorPack* Aux_CTX_Tensors,
     const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     const NVTETensor rng_state,
     size_t max_seqlen_q, size_t max_seqlen_kv,
     bool is_training, float attn_scale, float dropout,
@@ -647,6 +695,10 @@ void nvte_fused_attn_fwd(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
   const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
@@ -693,6 +745,7 @@ void nvte_fused_attn_fwd(
       input_Q, input_K, input_V, input_Bias, output_O,
       Aux_CTX_Tensors,
       input_cu_seqlens_q, input_cu_seqlens_kv,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
       input_rng_state,
       wkspace, stream, handle);
 #else
@@ -732,6 +785,10 @@ void nvte_fused_attn_bwd(
     NVTETensor dBias,
     const NVTETensor cu_seqlens_q,
     const NVTETensor cu_seqlens_kv,
+    const NVTETensor seq_offsets_q,
+    const NVTETensor seq_offsets_k,
+    const NVTETensor seq_offsets_v,
+    const NVTETensor seq_offsets_o,
     size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
@@ -742,6 +799,10 @@ void nvte_fused_attn_bwd(
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor*>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor*>(cu_seqlens_kv);
+  const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
+  const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
+  const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
+  const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
   const Tensor *input_Q = reinterpret_cast<const Tensor*>(Q);
   const Tensor *input_K = reinterpret_cast<const Tensor*>(K);
   const Tensor *input_V = reinterpret_cast<const Tensor*>(V);
@@ -802,6 +863,7 @@ void nvte_fused_attn_bwd(
       output_S,
       output_dQ, output_dK, output_dV, output_dBias,
       input_cu_seqlens_q, input_cu_seqlens_kv,
+      input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
       input_rng_state, wkspace, stream, handle);
 #else
   const char *err_msg =
...
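With the widened signatures above, every forward/backward entry point now threads the four ragged-offset tensors between the cu_seqlens and rng_state arguments. A hedged sketch of a THD call site follows; all handles (Q, K, V, Bias, S, O, workspace, the offset tensors, and the scalar settings) are assumed to be created elsewhere, and the argument order mirrors the nvte_fused_attn_fwd diff above.

// Sketch only: Q/K/V are THD-layout tensors and seq_offsets_* hold the
// per-sequence ragged offsets; BSHD/SBHD callers pass dummy tensors for
// these four arguments instead (see the framework bindings further down).
nvte_fused_attn_fwd(Q, K, V, Bias, S, O, &aux_ctx_tensors,
                    cu_seqlens_q, cu_seqlens_kv,
                    seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
                    rng_state, max_seqlen_q, max_seqlen_kv,
                    /*is_training=*/true, attn_scale, dropout,
                    NVTE_QKV_Layout::NVTE_THD_THD_THD, bias_type, mask_type,
                    workspace, stream);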
@@ -24,8 +24,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
     const Tensor *input_QKV, const Tensor *input_Bias,
     Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
+    const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
+    const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen,
@@ -35,8 +37,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     const Tensor *input_O, const Tensor *input_dO,
     const Tensor *input_Bias, Tensor *output_S,
     Tensor *output_dQKV, Tensor *output_dBias,
-    const Tensor *cu_seqlens, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
+    const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
+    const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
@@ -47,7 +51,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
     Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *rng_state,
+    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
+    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
@@ -59,7 +64,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S,
     Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *rng_state,
+    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
+    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_fwd(
@@ -72,7 +78,8 @@ void fused_attn_arbitrary_seqlen_fwd(
     const Tensor *input_V, const Tensor *input_Bias,
     Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *rng_state,
+    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
+    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd(
@@ -86,7 +93,8 @@ void fused_attn_arbitrary_seqlen_bwd(
     Tensor *output_dQ, Tensor *output_dK,
     Tensor *output_dV, Tensor *output_dBias,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *rng_state,
+    const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
+    const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 #endif  // CUDNN_VERSION >= 8900
...
@@ -1239,27 +1239,37 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   NVTETensorPack aux_output_tensors;
   nvte_tensor_pack_create(&aux_output_tensors);

+  auto dummy_ragged_offset_tensor =
+      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   TensorWrapper query_workspace_tensor;
   if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
     assert(q_max_seqlen == kv_max_seqlen);
     nvte_fused_attn_fwd_qkvpacked(
         qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(),
-        q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
-        mask_type, query_workspace_tensor.data(), nullptr);
+        &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
+        q_max_seqlen, is_training, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
+        nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
-    nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(),
-                                 s_tensor.data(), o_tensor.data(), &aux_output_tensors,
-                                 q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                                 dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
-                                 is_training, scaling_factor, dropout_probability, qkv_layout,
-                                 bias_type, mask_type, query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_fwd_kvpacked(
+        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
+        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_rng_state_tensor.data(), q_max_seqlen,
+        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
+        mask_type, query_workspace_tensor.data(), nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
     nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
-                        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                        dummy_rng_state_tensor.data(),
+                        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
+                        dropout_probability, qkv_layout, bias_type, mask_type,
                         query_workspace_tensor.data(), nullptr);
   } else {
     NVTE_ERROR("Unsupported QKVLayout.");
@@ -1294,6 +1304,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
   TensorWrapper query_workspace_tensor;

+  auto dummy_ragged_offset_tensor =
+      TensorWrapper(nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
   if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
     assert(q_max_seqlen == kv_max_seqlen);
     auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
@@ -1304,8 +1316,10 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         s_tensor.data(),  // not used for F16
         s_tensor.data(),  // not used for F16
         &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-        q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        q_max_seqlen, scaling_factor, dropout_probability,
+        qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
     auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
@@ -1319,9 +1333,12 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         s_tensor.data(),  // not used for F16
         s_tensor.data(),  // not used for F16
         &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        q_max_seqlen, kv_max_seqlen, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
+        nullptr);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
     auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
     auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
@@ -1336,11 +1353,15 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         doutput_tensor.data(),
         s_tensor.data(),  // not used for F16
         s_tensor.data(),  // not used for F16
-        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
-        dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-        kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        query_workspace_tensor.data(), nullptr);
+        &aux_input_tensors,
+        dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+        dbias_tensor.data(),
+        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
+        qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
+        nullptr);
   } else {
     NVTE_ERROR("Unsupported QKVLayout.");
   }
@@ -1416,6 +1437,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
   auto workspace_tensor = TensorWrapper(workspace, std::vector<size_t>{descriptor.wkspace_size},
                                         descriptor.wkspace_dtype);

+  auto dummy_ragged_offset_tensor =
+      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   /* Call the underly NVTE API */
   if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
     auto qkv = buffers[0];
@@ -1423,9 +1446,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
     nvte_fused_attn_fwd_qkvpacked(
         qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
-        &aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
-        descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
-        bias_type, mask_type, workspace_tensor.data(), stream);
+        &aux_output_tensors, q_cu_seqlens_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        rng_state_tensor.data(), q_max_seqlen, descriptor.is_training,
+        descriptor.scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+        workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -1437,9 +1463,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     nvte_fused_attn_fwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
         &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-        workspace_tensor.data(), stream);
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
+        descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
+        mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -1453,9 +1481,12 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
     nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
                         s_tensor.data(), o_tensor.data(), &aux_output_tensors,
                         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
                         rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
-                        descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
-                        bias_type, mask_type, workspace_tensor.data(), stream);
+                        descriptor.is_training, scaling_factor,
+                        dropout_probability, qkv_layout, bias_type, mask_type,
+                        workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
@@ -1496,15 +1527,20 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
   nvte_tensor_pack_create(&aux_input_tensors);

   TensorWrapper query_workspace_tensor;

+  auto dummy_ragged_offset_tensor =
+      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
                       doutput_tensor.data(),
                       s_tensor.data(),  // not used for F16
                       s_tensor.data(),  // not used for F16
-                      &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                      dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-                      kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-                      dropout_probability, qkv_layout, bias_type, mask_type,
-                      query_workspace_tensor.data(), nullptr);
+                      &aux_input_tensors,
+                      dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+                      dbias_tensor.data(),
+                      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+                      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                      q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
+                      qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), nullptr);

   auto work_shape = MakeShapeVector(query_workspace_tensor.shape());
   return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype());
@@ -1574,6 +1610,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto wkspace_dtype = descriptor.wkspace_dtype;
   auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype);

+  auto dummy_ragged_offset_tensor =
+      TensorWrapper(nullptr, std::vector<size_t>{input_batch + 1}, DType::kInt32);
   /* Call the underly NVTE API */
   if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
     auto qkv = buffers[0];
@@ -1586,8 +1624,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
       s_tensor.data(),  // not used for F16
       s_tensor.data(),  // not used for F16
       &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-      q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-      workspace_tensor.data(), stream);
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      q_max_seqlen, scaling_factor, dropout_probability,
+      qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -1605,9 +1645,11 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
       s_tensor.data(),  // not used for F16
       s_tensor.data(),  // not used for F16
       &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
-      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-      workspace_tensor.data(), stream);
+      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      q_max_seqlen, kv_max_seqlen, scaling_factor,
+      dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
     auto q = buffers[0];
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -1629,10 +1671,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
       s_tensor.data(),  // not used for F16
       s_tensor.data(),  // not used for F16
       &aux_input_tensors, dq_tensor.data(), dk_tensor.data(),
-      dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
-      kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen,
-      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
-      workspace_tensor.data(), stream);
+      dv_tensor.data(), dbias_tensor.data(),
+      q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+      q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
+      qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
...
...@@ -647,10 +647,13 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -647,10 +647,13 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// create workspace // create workspace
TensorWrapper workspace; TensorWrapper workspace;
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
...@@ -664,7 +667,9 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -664,7 +667,9 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream()); qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
@@ -730,10 +735,13 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
     // create workspace
     TensorWrapper workspace;
+    auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
     // populate tensors with appropriate shapes and dtypes
     nvte_fused_attn_bwd_qkvpacked(
         te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
-        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
+        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
+        dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        max_seqlen, attn_scale, p_dropout,
         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
     // allocate memory for workspace
@@ -743,7 +751,9 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
     // execute kernel
     nvte_fused_attn_bwd_qkvpacked(
         te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
-        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), max_seqlen, attn_scale, p_dropout,
+        te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
+        dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        max_seqlen, attn_scale, p_dropout,
         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
     // destroy tensor wrappers
@@ -816,10 +826,13 @@ void te_fused_attn_fwd_kvpacked(
     // create workspace
     TensorWrapper workspace;
+    auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
     // populate tensors with appropriate shapes and dtypes
     nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
                                  te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                                 te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
+                                 te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
+                                 dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                                 dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
                                  max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
                                  bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
@@ -834,7 +847,9 @@ void te_fused_attn_fwd_kvpacked(
     // execute the kernel
     nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(),
                                  te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                                 te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q,
+                                 te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
+                                 dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                                 dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
                                  max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
                                  bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
@@ -909,11 +924,14 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
     // create workspace
     TensorWrapper workspace;
+    auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
     // populate tensors with appropriate shapes and dtypes
     nvte_fused_attn_bwd_kvpacked(
         te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
         &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
+        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
     // allocate memory for workspace
@@ -924,7 +942,9 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
     nvte_fused_attn_bwd_kvpacked(
         te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
         &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
+        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
     // destroy tensor wrappers
@@ -989,10 +1009,13 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
     // create workspace
     TensorWrapper workspace;
+    auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
     // populate tensors with appropriate shapes and dtypes
     nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                         te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                        te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
+                        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
                         is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
                         attn_mask_type_enum, workspace.data(), Q.stream());
@@ -1008,7 +1031,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
     // execute the kernel
     nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                         te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
-                        te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
+                        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
                         is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
                         attn_mask_type_enum, workspace.data(), Q.stream());
@@ -1084,11 +1109,14 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
     // create workspace
     TensorWrapper workspace;
+    auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
     // populate tensors with appropriate shapes and dtypes
     nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                         te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                         te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-                        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
+                        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
                         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
                         Q.stream());
@@ -1100,7 +1128,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
     nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                         te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                         te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
-                        te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
+                        te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        dummy_seq_offsets.data(), dummy_seq_offsets.data(),
+                        max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
                         qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
                         Q.stream());
...
@@ -83,6 +83,10 @@ def fused_attn_fwd_qkvpacked(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
@@ -118,6 +122,14 @@ def fused_attn_fwd_qkvpacked(
     attn_bias: torch.Tensor, default = None
                 input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                 shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
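For variable-length (THD) inputs, these offset tensors can be built from the per-sequence lengths with a zero-padded cumulative sum. A minimal sketch, assuming sequences are tightly packed with no padding between them; the tensor names and values are illustrative, not part of the API:

    import torch

    # per-sequence lengths for a batch of three variable-length sequences (illustrative)
    seqlens = torch.tensor([7, 3, 5], dtype=torch.int32, device="cuda")

    # cumulative sequence offsets, shape [batch_size + 1]: tensor([0, 7, 10, 15])
    seq_offsets = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device="cuda")
    seq_offsets[1:] = torch.cumsum(seqlens, dim=0)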
@@ -225,6 +237,7 @@ def fused_attn_fwd_qkvpacked(
         max_seqlen, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens, qkv, qkv_dtype,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias,
         rng_gen, rng_elts_per_thread,
     )
@@ -243,6 +256,10 @@ def fused_attn_bwd_qkvpacked(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -286,6 +303,14 @@ def fused_attn_bwd_qkvpacked(
                 e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                 please see FusedAttention module for details on supported backends.
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -361,6 +386,7 @@ def fused_attn_bwd_qkvpacked(
         max_seqlen, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
@@ -379,6 +405,10 @@ def fused_attn_fwd_kvpacked(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
@@ -421,6 +451,14 @@ def fused_attn_fwd_kvpacked(
     attn_bias: torch.Tensor, default = None
                 input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                 shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
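In cross-attention the Q and KV sides generally have different lengths, so their offsets differ. A minimal sketch under the same tightly-packed assumption as above (names and values are illustrative); since K and V come from one packed KV tensor they share lengths, and O has Q's lengths, so those offsets can be reused:

    import torch

    def to_offsets(seqlens: torch.Tensor) -> torch.Tensor:
        # zero-padded cumulative sum -> shape [batch_size + 1]
        offsets = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device=seqlens.device)
        offsets[1:] = torch.cumsum(seqlens, dim=0)
        return offsets

    seqlens_q = torch.tensor([4, 2], dtype=torch.int32, device="cuda")
    seqlens_kv = torch.tensor([9, 6], dtype=torch.int32, device="cuda")

    seq_offsets_q = to_offsets(seqlens_q)    # tensor([0, 4, 6])
    seq_offsets_kv = to_offsets(seqlens_kv)  # tensor([0, 9, 15])
    seq_offsets_k = seq_offsets_v = seq_offsets_kv
    seq_offsets_o = seq_offsets_q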
@@ -529,6 +567,7 @@ def fused_attn_fwd_kvpacked(
         max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
         attn_bias, rng_gen, rng_elts_per_thread,
     )
@@ -550,6 +589,10 @@ def fused_attn_bwd_kvpacked(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -600,6 +643,14 @@ def fused_attn_bwd_kvpacked(
                 e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                 please see FusedAttention module for details on supported backends.
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of QKV in FP8 computations
     d_scale_s: torch.Tensor, default = None
@@ -679,6 +730,7 @@ def fused_attn_bwd_kvpacked(
         max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
@@ -698,6 +750,10 @@ def fused_attn_fwd(
     qkv_dtype: tex.DType,
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
     attn_bias: torch.Tensor = None,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     q_scale_s: torch.Tensor = None,
@@ -744,6 +800,14 @@ def fused_attn_fwd(
     attn_bias: torch.Tensor, default = None
                 input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
                 shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of Q, K and V in FP8 computations
     d_scale_s: torch.Tensor, default = None
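For self-attention over tightly packed THD tensors, the cumulative offsets carry the same counts as the cu_seqlens tensors, so a single tensor can serve all four arguments; this is a hedged sketch under that no-padding assumption (names illustrative), and it does not hold once padding is inserted between sequences:

    import torch

    seqlens = torch.tensor([5, 8, 2], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    # tightly packed self-attention: Q, K, V and O can share one offsets tensor
    seq_offsets_q = seq_offsets_k = seq_offsets_v = seq_offsets_o = cu_seqlens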
@@ -854,6 +918,7 @@ def fused_attn_fwd(
         max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o,
         attn_bias, rng_gen, rng_elts_per_thread,
     )
@@ -876,6 +941,10 @@ def fused_attn_bwd(
     dqkv_dtype: tex.DType,
     aux_ctx_tensors: List[torch.Tensor],
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
+    seq_offsets_q: torch.Tensor = None,
+    seq_offsets_k: torch.Tensor = None,
+    seq_offsets_v: torch.Tensor = None,
+    seq_offsets_o: torch.Tensor = None,
     d_scale_qkv: torch.Tensor = None,
     d_scale_s: torch.Tensor = None,
     d_scale_o: torch.Tensor = None,
@@ -929,6 +998,14 @@ def fused_attn_bwd(
                 e.g. aux_ctx_tensors = [M, ZInv, rng_state]
     fused_attention_backend: tex.NVTE_Fused_Attn_Backend
                 please see FusedAttention module for details on supported backends.
+    seq_offsets_q: torch.Tensor, default = None
+                cumulative sequence offsets for Q; shape [batch_size + 1]
+    seq_offsets_k: torch.Tensor, default = None
+                cumulative sequence offsets for K; shape [batch_size + 1]
+    seq_offsets_v: torch.Tensor, default = None
+                cumulative sequence offsets for V; shape [batch_size + 1]
+    seq_offsets_o: torch.Tensor, default = None
+                cumulative sequence offsets for O; shape [batch_size + 1]
     d_scale_qkv: torch.Tensor, default = None
                 input tensor for the dequantization of Q, K and V in FP8 computations
     d_scale_s: torch.Tensor, default = None
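The backward entry point takes the same four offset tensors as the forward pass, and the integer offsets themselves are not differentiable, so they receive None gradients. A toy autograd sketch of that bookkeeping; it only stands in for the real module and is not the library's implementation:

    import torch

    class _OffsetAwareAttn(torch.autograd.Function):
        """Toy stand-in: save the offset tensors in forward, return None grads for them."""

        @staticmethod
        def forward(ctx, qkv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o):
            ctx.save_for_backward(seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o)
            return qkv  # placeholder for the attention output

        @staticmethod
        def backward(ctx, d_o):
            seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = ctx.saved_tensors
            # a real module would call fused_attn_bwd(...) here with these offsets;
            # the offset arguments get None gradients
            return d_o, None, None, None, None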
@@ -1012,6 +1089,7 @@ def fused_attn_bwd(
         max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill,
         QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type],
         cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors,
+        seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o,
         d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp,
         q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv,
     )
...
@@ -31,6 +31,10 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
                 const at::Tensor cu_seqlens,
                 const at::Tensor QKV,
                 const transformer_engine::DType qkv_type,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> scale_S,
@@ -54,6 +58,10 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
                 const transformer_engine::DType qkv_type,
                 const transformer_engine::DType dqkv_type,
                 const std::vector<at::Tensor> Aux_CTX_Tensors,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> descale_O,
@@ -76,6 +84,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
                 const at::Tensor Q,
                 const at::Tensor KV,
                 const transformer_engine::DType qkv_type,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> scale_S,
@@ -101,6 +113,10 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
                 const transformer_engine::DType qkv_type,
                 const transformer_engine::DType dqkv_type,
                 const std::vector<at::Tensor> Aux_CTX_Tensors,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> descale_O,
@@ -124,6 +140,10 @@ std::vector<at::Tensor> fused_attn_fwd(
                 const at::Tensor K,
                 const at::Tensor V,
                 const transformer_engine::DType qkv_type,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> scale_S,
@@ -150,6 +170,10 @@ std::vector<at::Tensor> fused_attn_bwd(
                 const transformer_engine::DType qkv_type,
                 const transformer_engine::DType dqkv_type,
                 const std::vector<at::Tensor> Aux_CTX_Tensors,
+                const c10::optional<at::Tensor> seq_offsets_q,
+                const c10::optional<at::Tensor> seq_offsets_k,
+                const c10::optional<at::Tensor> seq_offsets_v,
+                const c10::optional<at::Tensor> seq_offsets_o,
                 const c10::optional<at::Tensor> descale_QKV,
                 const c10::optional<at::Tensor> descale_S,
                 const c10::optional<at::Tensor> descale_O,
...