Unverified Commit 70d3251f authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[C/PyTorch] Simplify THD offset tensors (#927)



* simplify offset tensors
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; tests pass
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix C lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_offset with with_padding
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_padding with padded
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes after merge
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for fused attn fwd/bwd calls
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust spacing in docstring
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch tests; fix paddle api
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn_biases
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix AttnFuncWithCP backward
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn with CP
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 94a426b0
...@@ -832,24 +832,6 @@ def _run_dot_product_attention( ...@@ -832,24 +832,6 @@ def _run_dot_product_attention(
inp[i].requires_grad = True inp[i].requires_grad = True
inp_orig[i].requires_grad = True inp_orig[i].requires_grad = True
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
if qkv_format == "thd":
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient # Create output gradient
qkv_format_kv = "_".join(qkv_format) qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq") qkv_format_kv = qkv_format_kv.replace("s", "sq")
...@@ -928,10 +910,8 @@ def _run_dot_product_attention( ...@@ -928,10 +910,8 @@ def _run_dot_product_attention(
max_seqlen_kv=config.max_seqlen_kv, max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
seq_offsets_k=seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
attn_mask_type=config.attn_mask_type, attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn, checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type, core_attention_bias_type=config.attn_bias_type,
...@@ -1957,8 +1937,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -1957,8 +1937,6 @@ class _custom_mha_fp8(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -2038,8 +2016,6 @@ class _custom_mha_fp8(torch.autograd.Function): ...@@ -2038,8 +2016,6 @@ class _custom_mha_fp8(torch.autograd.Function):
FusedAttnBackend["FP8"], FusedAttnBackend["FP8"],
None, None,
None, None,
None,
None,
fwd_scale_inverses[META_QKV], # d_scale_qkv, fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s, fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o, fwd_scale_inverses[META_O], # d_scale_o,
......
...@@ -8,9 +8,11 @@ import subprocess ...@@ -8,9 +8,11 @@ import subprocess
from test_fused_attn import ( from test_fused_attn import (
ModelConfig, ModelConfig,
_is_flash_attention_2_available, _is_flash_attention_2_available,
_cudnn_version,
) )
from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
model_configs_flash_attn = { model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias # test: b, h, hg, d, sq, skv, p, mask, bias
...@@ -58,7 +60,7 @@ model_configs_fused_attn = { ...@@ -58,7 +60,7 @@ model_configs_fused_attn = {
} }
@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
......
...@@ -196,21 +196,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -196,21 +196,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV // NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen, bool is_training,
const NVTETensor seq_offsets_o, const NVTETensor rng_state, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) { NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens); const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV); const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias); const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
...@@ -252,8 +247,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -252,8 +247,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -272,21 +266,19 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -272,21 +266,19 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
} }
} }
// NVTE fused attention BWD with packed QKV // NVTE fused attention BWD with packed QKV
void nvte_fused_attn_bwd_qkvpacked( void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, const NVTETensor S, NVTETensor dP,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen, const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) { NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens); const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV); const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O); const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO); const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
...@@ -338,8 +330,7 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -338,8 +330,7 @@ void nvte_fused_attn_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
input_rng_state, wkspace, stream, handle);
#else #else
const char *err_msg = const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention " "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
...@@ -366,21 +357,18 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -366,21 +357,18 @@ void nvte_fused_attn_bwd_qkvpacked(
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
size_t max_seqlen_kv, bool is_training, float attn_scale, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) { NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV); const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
...@@ -426,8 +414,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const ...@@ -426,8 +414,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -450,18 +438,16 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -450,18 +438,16 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) { NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV); const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O); const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
...@@ -519,9 +505,9 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -519,9 +505,9 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
stream, handle); handle);
#else #else
const char *err_msg = const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention " "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
...@@ -549,9 +535,8 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -549,9 +535,8 @@ void nvte_fused_attn_bwd_kvpacked(
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
...@@ -560,10 +545,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -560,10 +545,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state); const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K); const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
...@@ -601,8 +584,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -601,8 +584,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else #else
NVTE_ERROR( NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
...@@ -625,20 +608,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -625,20 +608,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_kv, float attn_scale, float dropout,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd); NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine; using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q); const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv); const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q); const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K); const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
const Tensor *input_V = reinterpret_cast<const Tensor *>(V); const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
...@@ -690,8 +670,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -690,8 +670,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
input_rng_state, wkspace, stream, handle); handle);
#else #else
const char *err_msg = const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention " "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
...@@ -53,9 +53,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -53,9 +53,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
...@@ -297,8 +297,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -297,8 +297,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto plan_workspace_size = mha_graph->get_workspace_size(); auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
if (workspace == nullptr) { if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size; *workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
return; return;
} }
...@@ -330,17 +332,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -330,17 +332,29 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
} }
if (is_ragged) { if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ; constexpr size_t nthreads_per_block = 128;
variant_pack[offset_k] = devPtrSeqOffsetsK; const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
variant_pack[offset_v] = devPtrSeqOffsetsV; void *devOffsetsQ =
variant_pack[offset_o] = devPtrSeqOffsetsO; static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
} }
if (is_dropout) { if (is_dropout) {
variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset; variant_pack[dropout_offset] = devPtrDropoutOffset;
} }
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) { } catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what()); NVTE_ERROR(e.what());
...@@ -354,9 +368,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -354,9 +368,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsK, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
void *devPtrSeqOffsetsV, void *devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
...@@ -366,9 +380,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -366,9 +380,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f); bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
try { try {
FADescriptor_v1 descriptor{b, FADescriptor_v1 descriptor{b,
...@@ -646,8 +657,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -646,8 +657,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
// Exit to request upper level API to allocate memory if needed // Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
if (workspace == nullptr) { if (workspace == nullptr) {
*workspace_size = plan_workspace_size + actual_seqlen_workspace_size; *workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
return; return;
} }
...@@ -692,10 +705,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -692,10 +705,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
} }
if (is_ragged) { if (is_ragged) {
variant_pack[offset_q] = devPtrSeqOffsetsQ; constexpr size_t nthreads_per_block = 128;
variant_pack[offset_k] = devPtrSeqOffsetsK; const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
variant_pack[offset_v] = devPtrSeqOffsetsV; void *devOffsetsQ =
variant_pack[offset_o] = devPtrSeqOffsetsO; static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
} }
if (is_dropout) { if (is_dropout) {
...@@ -715,8 +741,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -715,8 +741,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -744,10 +769,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -744,10 +769,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void *devPtrO = output_O->data.dptr; void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr; void *devPtrS = nullptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -801,9 +823,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -801,9 +823,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -825,8 +846,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -825,8 +846,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -866,10 +886,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -866,10 +886,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
devPtrSoftmaxStats = output_S->data.dptr; devPtrSoftmaxStats = output_S->data.dptr;
void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr; void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset = void *devPtrDropoutOffset =
...@@ -881,9 +898,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -881,9 +898,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsetsQ, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets,
devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
workspace->data.dptr, &workspace_size, stream, handle); handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -905,9 +922,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -905,9 +922,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
...@@ -936,10 +952,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -936,10 +952,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -993,9 +1007,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -993,9 +1007,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -1019,9 +1032,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1019,9 +1032,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
...@@ -1060,10 +1072,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1060,10 +1072,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr; void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset = void *devPtrDropoutOffset =
...@@ -1076,8 +1086,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1076,8 +1086,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); &workspace_size, stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -1094,15 +1104,17 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -1094,15 +1104,17 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
} }
} }
void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, NVTE_Mask_Type mask_type, const Tensor *input_Q,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *input_K, const Tensor *input_V,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *input_Bias, Tensor *output_O,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
...@@ -1122,10 +1134,8 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1122,10 +1134,8 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
if (Aux_CTX_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
...@@ -1179,9 +1189,8 @@ void fused_attn_arbitrary_seqlen_fwd( ...@@ -1179,9 +1189,8 @@ void fused_attn_arbitrary_seqlen_fwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
...@@ -1205,9 +1214,9 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -1205,9 +1214,9 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype; const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr; void *devPtrQ = input_Q->data.dptr;
...@@ -1234,10 +1243,8 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -1234,10 +1243,8 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr;
void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr;
void *devPtrDropoutSeed = rng_state->data.dptr; void *devPtrDropoutSeed = rng_state->data.dptr;
void *devPtrDropoutOffset = void *devPtrDropoutOffset =
...@@ -1250,8 +1257,8 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -1250,8 +1257,8 @@ void fused_attn_arbitrary_seqlen_bwd(
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrO,
devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, devPtrSeqOffsetsO, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); &workspace_size, stream, handle);
if (workspace_size > 0) { if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) { if (workspace->data.dptr == nullptr) {
......
...@@ -22,8 +22,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( ...@@ -22,8 +22,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
...@@ -31,8 +30,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( ...@@ -31,8 +30,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...@@ -41,9 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -41,9 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...@@ -52,20 +49,21 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( ...@@ -52,20 +49,21 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd( void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout,
...@@ -73,9 +71,9 @@ void fused_attn_arbitrary_seqlen_bwd( ...@@ -73,9 +71,9 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900 #endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -360,6 +360,38 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu ...@@ -360,6 +360,38 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
} }
} }
// Convert padded cumulative sequence lengths into element offsets for Q, K, V and O.
//
// One thread handles one of the (b + 1) entries of the cu_seqlens_*_padded arrays and
// scales it by the per-token element count implied by the QKV layout group:
//   - hd_hd_hd:        Q, K, V are separate tensors; K/V use the GQA head count `hg`.
//   - 3hd / h3d:       Q, K, V are interleaved in one tensor, so all three share the
//                      same offset with a per-token stride of 3 * h * d.
//   - hd_2hd / hd_h2d: Q is separate; K and V are packed together (stride 2 * hg * d).
// O always has shape [total_tokens, h, d], so its offset is h * d * cu_seqlens_q_padded.
//
// NOTE(review): offsets are int32_t, so h * d * cu_seqlens_padded must fit in 31 bits —
// this matches the existing API contract and is not widened here.
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
                                             size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
                                             int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
                                             int32_t *offsets_k, int32_t *offsets_v,
                                             int32_t *offsets_o) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < b + 1) {
    // Load each padded cumulative length once; the original re-read global memory
    // per case and round-tripped offsets_k/offsets_q through global memory.
    const int32_t q_pad = cu_seqlens_q_padded[tid];
    const int32_t kv_pad = cu_seqlens_kv_padded[tid];
    int32_t off_q = 0;
    int32_t off_k = 0;
    int32_t off_v = 0;
    switch (layout_group) {
      case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
        off_q = h * d * q_pad;
        off_k = hg * d * kv_pad;
        off_v = off_k;
        break;
      case NVTE_QKV_Layout_Group::NVTE_3HD:
      case NVTE_QKV_Layout_Group::NVTE_H3D:
        off_q = 3 * h * d * q_pad;
        off_k = off_q;
        off_v = off_q;
        break;
      case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
      case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
        off_q = h * d * q_pad;
        off_k = 2 * hg * d * kv_pad;
        off_v = off_k;
        break;
        // Unhandled layout groups fall through with zero offsets instead of
        // leaving the output arrays uninitialized.
    }
    offsets_q[tid] = off_q;
    offsets_k[tid] = off_k;
    offsets_v[tid] = off_v;
    offsets_o[tid] = h * d * q_pad;
  }
}
} // namespace fused_attn } // namespace fused_attn
// get cuDNN data type // get cuDNN data type
......
...@@ -121,6 +121,11 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu ...@@ -121,6 +121,11 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens); int32_t *kv_seqlens);
// Kernel: expands padded cumulative sequence lengths (cu_seqlens_q_padded /
// cu_seqlens_kv_padded, each [b + 1]) into per-sequence element offsets for the
// Q, K, V and O tensors, scaled according to the QKV layout group.
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
                                             size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
                                             int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
                                             int32_t *offsets_k, int32_t *offsets_v,
                                             int32_t *offsets_o);
} // namespace fused_attn } // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
...@@ -166,21 +166,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -166,21 +166,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, this tensor should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
* *
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
...@@ -189,10 +183,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -189,10 +183,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing, * \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1. * it may be >= max(seqlen_i) for i=0,...batch_size-1.
...@@ -207,11 +198,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -207,11 +198,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*/ */
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen, bool is_training,
const NVTETensor seq_offsets_o, const NVTETensor rng_state, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
...@@ -227,21 +216,15 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -227,21 +216,15 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensor `cu_seqlens_padded` helps identify the correct offsets of different sequences
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * the offset tensor is not used in the attention calculation and can be set to empty `NVTETensor`.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, this tensor should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensor should be equal to `cu_seqlens`,
\verbatim
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
* *
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
...@@ -253,10 +236,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -253,10 +236,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[out] dQKV The gradient of the QKV tensor. * \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1].
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing, * \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1. * it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] attn_scale Scaling factor for Q * K.T.
...@@ -267,13 +247,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, ...@@ -267,13 +247,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_bwd_qkvpacked( void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, const NVTETensor S, NVTETensor dP,
NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen, const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream); NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input. /*! \brief Compute dot product attention with packed KV input.
* *
...@@ -292,57 +273,49 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -292,57 +273,49 @@ void nvte_fused_attn_bwd_qkvpacked(
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensors should be equal to
\verbatim * `cu_seqlens_q` and `cu_seqlens_kv` respectively.
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
* *
* \param[in] Q The Q tensor, in HD layouts. * \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] rng_state Seed and offset of CUDA random number generator. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV. * \param[in] is_training Whether this is in training mode or inference.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] is_training Whether this is in training mode or inference. * \param[in] dropout Dropout probability.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] dropout Dropout probability. * \param[in] bias_type Bias type.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] attn_mask_type Attention mask type.
* \param[in] bias_type Bias type. * \param[in] workspace Workspace tensor.
* \param[in] attn_mask_type Attention mask type. * \param[in] stream CUDA stream used for this operation.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
size_t max_seqlen_kv, bool is_training, float attn_scale, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
...@@ -357,59 +330,52 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const ...@@ -357,59 +330,52 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensors should be equal to
\verbatim * `cu_seqlens_q` and `cu_seqlens_kv` respectively.
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
* *
* \param[in] Q The Q tensor, in HD layouts. * \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in H2D or 2HD layouts. * \param[in] KV The KV tensor, in H2D or 2HD layouts.
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor. * \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor. * \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor. * \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor. * \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor. * \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV. * \param[in] attn_scale Scaling factor for Q * K.T.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] dropout Dropout probability.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] dropout Dropout probability. * \param[in] bias_type Bias type.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] attn_mask_type Attention mask type.
* \param[in] bias_type Bias type. * \param[in] workspace Workspace tensor.
* \param[in] attn_mask_type Attention mask type. * \param[in] stream CUDA stream used for this operation.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_bwd_kvpacked( void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream); NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V. /*! \brief Compute dot product attention with separate Q, K and V.
* *
...@@ -431,66 +397,48 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -431,66 +397,48 @@ void nvte_fused_attn_bwd_kvpacked(
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensors should be equal to
\verbatim * `cu_seqlens_q` and `cu_seqlens_kv` respectively.
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
 * `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded = [0, 2, 4, 7, 9]`.
* *
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
* \param[in] K The K tensor. * \param[in] K The K tensor.
* \param[in] V The V tensor. * \param[in] V The V tensor.
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] rng_state Seed and offset of CUDA random number generator. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * \param[in] is_training Whether this is in training mode or inference.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] is_training Whether this is in training mode or inference. * \param[in] dropout Dropout probability.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] qkv_layout QKV tensors' layout.
* \param[in] dropout Dropout probability. * \param[in] bias_type Bias type.
* \param[in] qkv_layout QKV tensors' layout. * \param[in] attn_mask_type Attention mask type.
* \param[in] bias_type Bias type. * \param[in] workspace Workspace tensor.
* \param[in] attn_mask_type Attention mask type. * \param[in] stream CUDA stream used for this operation.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
...@@ -510,73 +458,56 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -510,73 +458,56 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* *
* Notes: * Notes:
* *
* Tensors `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` and `seq_offsets_o` * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded`
* help identify the correct offsets of different sequences in tensors Q, K, V and O. * help identify the correct offsets of different sequences in tensors Q, K, V and O.
* When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`, * When the QKV format (`nvte_get_qkv_format(qkv_layout)`) is `bshd` or `sbhd`,
* offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s. * offset tensors are not used in the attention calculation and can be set to empty `NVTETensor`s.
* When the QKV format is `thd`, these tensors should follow the following rules. * When the QKV format is `thd`, these tensors should follow the following rules.
* When there is no padding between sequences, the offset tensors are, * When there is no padding between sequences, the offset tensors should be equal to
\verbatim * `cu_seqlens_q` and `cu_seqlens_kv` respectively.
qkv_group = nvte_get_qkv_layout_group(qkv_layout)
if qkv_group == 'hd_hd_hd':
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ['3hd', 'h3d']:
seq_offsets_q = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_attn_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_attn_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ['hd_2hd', 'hd_h2d']:
seq_offsets_q = num_attn_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_o = num_attn_heads * head_dim * cu_seqlens_q
\endverbatim
* When there is padding between sequences, users are responsible to adjust the offsets as needed. * When there is padding between sequences, users are responsible to adjust the offsets as needed.
* For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have * For example, a tensor of 4 sequences `[a, PAD, b, b, c, PAD, PAD, d, d]` should have
* `cu_seqlens = [0, 1, 3, 4, 6]` and `seq_offsets = [0, 2, 4, 7, 9]`. * `cu_seqlens = [0, 1, 3, 4, 6]` and `cu_seqlens_padded= [0, 2, 4, 7, 9]`.
* *
* \param[in] Q The Q tensor. * \param[in] Q The Q tensor.
* \param[in] K The K tensor. * \param[in] K The K tensor.
* \param[in] V The V tensor. * \param[in] V The V tensor.
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor. * \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor. * \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor. * \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state. * e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor. * \param[out] dQ The gradient of the Q tensor.
* \param[out] dK The gradient of the K tensor. * \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor. * \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * \param[in] attn_scale Scaling factor for Q * K.T.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] dropout Dropout probability.
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] qkv_layout QKV tensors' layout.
* \param[in] dropout Dropout probability. * \param[in] bias_type Bias type.
* \param[in] qkv_layout QKV tensors' layout. * \param[in] attn_mask_type Attention mask type.
* \param[in] bias_type Bias type. * \param[in] workspace Workspace tensor.
* \param[in] attn_mask_type Attention mask type. * \param[in] stream CUDA stream used for this operation.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, size_t max_seqlen_kv, float attn_scale, float dropout,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream);
NVTETensor workspace, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -143,19 +143,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -143,19 +143,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper query_workspace_tensor; TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen); assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, is_training, scaling_factor, dropout_probability,
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, nullptr);
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr); nullptr);
...@@ -164,7 +162,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -164,7 +162,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
...@@ -208,15 +205,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -208,15 +205,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, bias_type, mask_type, query_workspace_tensor.data(), nullptr);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -230,7 +225,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -230,7 +225,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(), dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr); nullptr);
...@@ -251,7 +245,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -251,7 +245,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
...@@ -340,10 +333,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -340,10 +333,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor,
dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen, dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -355,7 +346,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -355,7 +346,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
...@@ -373,7 +363,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -373,7 +363,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
s_tensor.data(), o_tensor.data(), &aux_output_tensors, s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream); bias_type, mask_type, workspace_tensor.data(), stream);
...@@ -426,7 +415,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -426,7 +415,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr); bias_type, mask_type, query_workspace_tensor.data(), nullptr);
...@@ -507,15 +495,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -507,15 +495,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[10]; auto dqkv = buffers[10];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, bias_type, mask_type, workspace_tensor.data(), stream);
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0]; auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
...@@ -533,7 +519,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -533,7 +519,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
...@@ -559,7 +544,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -559,7 +544,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
......
...@@ -640,12 +640,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -640,12 +640,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
...@@ -655,12 +654,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -655,12 +654,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
output_s->data.dptr = GetOptionalDataPtr(softmax_aux); output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, attn_mask_type_enum, workspace.data(), QKV.stream());
workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), QKV.stream());
QKV.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), QKV.stream());
QKV.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -818,12 +814,12 @@ void te_fused_attn_fwd_kvpacked( ...@@ -818,12 +814,12 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32); auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -833,12 +829,12 @@ void te_fused_attn_fwd_kvpacked( ...@@ -833,12 +829,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux); output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -916,7 +912,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -916,7 +912,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -929,7 +924,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -929,7 +924,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream()); bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
...@@ -1001,10 +995,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1001,10 +995,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), workspace.data(), Q.stream());
Q.stream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -1018,10 +1011,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1018,10 +1011,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), workspace.data(), Q.stream());
Q.stream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -1100,10 +1092,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1100,10 +1092,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), Q.stream());
Q.stream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
...@@ -1113,10 +1104,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p ...@@ -1113,10 +1104,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), attn_mask_type_enum, workspace.data(), Q.stream());
Q.stream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
...@@ -550,10 +550,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -550,10 +550,8 @@ class AttnFuncWithCP(torch.autograd.Function):
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
dropout_p, dropout_p,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -694,10 +692,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -694,10 +692,8 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k=seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
) )
) )
if len(rest) > 0: if len(rest) > 0:
...@@ -769,14 +765,12 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -769,14 +765,12 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type="padding" if padding else "no_mask", attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k=( cu_seqlens_kv_padded=(
None if seq_offsets_k is None else seq_offsets_k // 2 None
if cu_seqlens_kv_padded is None
else cu_seqlens_kv_padded // 2
), ),
seq_offsets_v=(
None if seq_offsets_v is None else seq_offsets_v // 2
),
seq_offsets_o=seq_offsets_o,
) )
) )
if len(rest) > 0: if len(rest) > 0:
...@@ -863,14 +857,12 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -863,14 +857,12 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type="padding" if padding else "no_mask", attn_mask_type="padding" if padding else "no_mask",
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=( cu_seqlens_q_padded=(
None if seq_offsets_q is None else seq_offsets_q // 2 None
), if cu_seqlens_q_padded is None
seq_offsets_k=seq_offsets_k, else cu_seqlens_q_padded // 2
seq_offsets_v=seq_offsets_v,
seq_offsets_o=(
None if seq_offsets_o is None else seq_offsets_o // 2
), ),
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
) )
) )
if len(rest) > 0: if len(rest) > 0:
...@@ -940,10 +932,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -940,10 +932,8 @@ class AttnFuncWithCP(torch.autograd.Function):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_bias=attn_bias_inputs[i % 2], attn_bias=attn_bias_inputs[i % 2],
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k=seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
) )
) )
if len(rest) > 0: if len(rest) > 0:
...@@ -1082,10 +1072,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1082,10 +1072,8 @@ class AttnFuncWithCP(torch.autograd.Function):
softmax_lse, softmax_lse,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
*rng_states, *rng_states,
*attn_biases, *attn_biases,
) )
...@@ -1106,10 +1094,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1106,10 +1094,10 @@ class AttnFuncWithCP(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout):
(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6]
(seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) = ctx.saved_tensors[6:10] (cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[6:8]
cp_size = get_distributed_world_size(ctx.cp_group) cp_size = get_distributed_world_size(ctx.cp_group)
rng_states = ctx.saved_tensors[10 : 10 + cp_size] rng_states = ctx.saved_tensors[8 : 8 + cp_size]
attn_biases = ctx.saved_tensors[10 + cp_size : 10 + cp_size * 2] attn_biases = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2]
rank = get_distributed_rank(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size]
...@@ -1224,10 +1212,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1224,10 +1212,8 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype], TE_DType[kv.dtype],
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p, dropout=ctx.dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1305,10 +1291,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1305,10 +1291,10 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype], TE_DType[kv.dtype],
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
None if seq_offsets_k is None else seq_offsets_k // 2, cu_seqlens_kv_padded=(
None if seq_offsets_v is None else seq_offsets_v // 2, None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2
seq_offsets_o, ),
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p, dropout=ctx.dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1392,10 +1378,10 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1392,10 +1378,10 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype], TE_DType[kv.dtype],
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
None if seq_offsets_q is None else seq_offsets_q // 2, cu_seqlens_q_padded=(
seq_offsets_k, None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2
seq_offsets_v, ),
None if seq_offsets_o is None else seq_offsets_o // 2, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p, dropout=ctx.dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1461,10 +1447,8 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1461,10 +1447,8 @@ class AttnFuncWithCP(torch.autograd.Function):
TE_DType[kv.dtype], TE_DType[kv.dtype],
aux_ctx_tensors, aux_ctx_tensors,
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
attn_scale=ctx.softmax_scale, attn_scale=ctx.softmax_scale,
dropout=ctx.dropout_p, dropout=ctx.dropout_p,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -1658,8 +1642,6 @@ class AttnFuncWithCP(torch.autograd.Function): ...@@ -1658,8 +1642,6 @@ class AttnFuncWithCP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
attn_dbias, attn_dbias,
None, None,
None, None,
...@@ -1675,10 +1657,8 @@ def attn_forward_func_with_cp( ...@@ -1675,10 +1657,8 @@ def attn_forward_func_with_cp(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
dropout_p, dropout_p,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -1721,10 +1701,8 @@ def attn_forward_func_with_cp( ...@@ -1721,10 +1701,8 @@ def attn_forward_func_with_cp(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
dropout_p, dropout_p,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -2593,8 +2571,6 @@ class FlashAttention(torch.nn.Module): ...@@ -2593,8 +2571,6 @@ class FlashAttention(torch.nn.Module):
max_seqlen_kv, max_seqlen_kv,
None, None,
None, None,
None,
None,
self.attention_dropout if self.training else 0.0, self.attention_dropout if self.training else 0.0,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -2690,10 +2666,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2690,10 +2666,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
is_training, is_training,
max_seqlen, max_seqlen,
cu_seqlens, cu_seqlens,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
qkv, qkv,
qkv_dtype, qkv_dtype,
attn_bias, attn_bias,
...@@ -2738,10 +2711,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2738,10 +2711,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -2806,10 +2776,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2806,10 +2776,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv_dtype, qkv_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -2830,14 +2797,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2830,14 +2797,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
ctx.save_for_backward( ctx.save_for_backward(
*qkvo_tensors, *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors
cu_seqlens,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
*fp8_tensors,
*aux_ctx_tensors,
) )
ctx.fp8_meta = fp8_meta ctx.fp8_meta = fp8_meta
ctx.max_seqlen = max_seqlen ctx.max_seqlen = max_seqlen
...@@ -2870,10 +2830,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2870,10 +2830,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
qkv, qkv,
out, out,
cu_seqlens, cu_seqlens,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
qkv_fp8, qkv_fp8,
out_fp8, out_fp8,
fwd_scales, fwd_scales,
...@@ -2939,10 +2896,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2939,10 +2896,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
fp8_dtype_backward, fp8_dtype_backward,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -2994,10 +2948,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -2994,10 +2948,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
ctx.qkv_dtype, ctx.qkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -3019,9 +2970,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -3019,9 +2970,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
# if no_bias or alibi, return dqkv # if no_bias or alibi, return dqkv
if ctx.attn_bias_type in ["no_bias", "alibi"]: if ctx.attn_bias_type in ["no_bias", "alibi"]:
return ( return (
None,
None,
None,
None, None,
None, None,
None, None,
...@@ -3045,9 +2993,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): ...@@ -3045,9 +2993,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
) )
# else, return (dqkv, dbias) # else, return (dqkv, dbias)
return ( return (
None,
None,
None,
None, None,
None, None,
None, None,
...@@ -3082,10 +3027,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3082,10 +3027,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
max_seqlen_kv, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
q, q,
kv, kv,
qkv_dtype, qkv_dtype,
...@@ -3139,10 +3082,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3139,10 +3082,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -3214,10 +3155,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3214,10 +3155,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
qkv_dtype, qkv_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -3241,10 +3180,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3241,10 +3180,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
*qkvo_tensors, *qkvo_tensors,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
*fp8_tensors, *fp8_tensors,
*aux_ctx_tensors, *aux_ctx_tensors,
) )
...@@ -3282,10 +3219,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3282,10 +3219,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
out, out,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
q_fp8, q_fp8,
kv_fp8, kv_fp8,
out_fp8, out_fp8,
...@@ -3355,10 +3290,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3355,10 +3290,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
fp8_dtype_backward, fp8_dtype_backward,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -3428,10 +3361,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3428,10 +3361,8 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
ctx.qkv_dtype, ctx.qkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -3460,8 +3391,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3460,8 +3391,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
dq, dq,
dkv, dkv,
None, None,
...@@ -3489,8 +3418,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): ...@@ -3489,8 +3418,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
dq, dq,
dkv, dkv,
None, None,
...@@ -3522,10 +3449,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3522,10 +3449,8 @@ class FusedAttnFunc(torch.autograd.Function):
max_seqlen_kv, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
q, q,
k, k,
v, v,
...@@ -3602,10 +3527,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3602,10 +3527,8 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_S],
...@@ -3727,10 +3650,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3727,10 +3650,8 @@ class FusedAttnFunc(torch.autograd.Function):
qkv_dtype, qkv_dtype,
fused_attention_backend, fused_attention_backend,
attn_bias, attn_bias,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -3763,10 +3684,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3763,10 +3684,8 @@ class FusedAttnFunc(torch.autograd.Function):
*qkvo_tensors, *qkvo_tensors,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
*fp8_tensors, *fp8_tensors,
*aux_ctx_tensors, *aux_ctx_tensors,
) )
...@@ -3805,10 +3724,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3805,10 +3724,8 @@ class FusedAttnFunc(torch.autograd.Function):
out, out,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
q_fp8, q_fp8,
k_fp8, k_fp8,
v_fp8, v_fp8,
...@@ -3882,10 +3799,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3882,10 +3799,8 @@ class FusedAttnFunc(torch.autograd.Function):
fp8_dtype_backward, fp8_dtype_backward,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_QKV], # d_scale_qkv,
fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_S], # d_scale_s,
fwd_scale_invs[META_O], # d_scale_o, fwd_scale_invs[META_O], # d_scale_o,
...@@ -3903,6 +3818,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -3903,6 +3818,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_bias_type, ctx.attn_bias_type,
ctx.attn_mask_type, ctx.attn_mask_type,
) )
if ctx.fp8_meta["recipe"].fp8_mha: if ctx.fp8_meta["recipe"].fp8_mha:
dq = Float8Tensor( dq = Float8Tensor(
data=dq_fp8, data=dq_fp8,
...@@ -4007,10 +3923,8 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -4007,10 +3923,8 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.qkv_dtype, ctx.qkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
ctx.fused_attention_backend, ctx.fused_attention_backend,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
None, None,
None, None,
None, None,
...@@ -4039,8 +3953,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -4039,8 +3953,6 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
dq, dq,
dk, dk,
dv, dv,
...@@ -4069,8 +3981,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -4069,8 +3981,6 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
None,
dq, dq,
dk, dk,
dv, dv,
...@@ -4186,10 +4096,8 @@ class FusedAttention(torch.nn.Module): ...@@ -4186,10 +4096,8 @@ class FusedAttention(torch.nn.Module):
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None, cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
...@@ -4271,31 +4179,9 @@ class FusedAttention(torch.nn.Module): ...@@ -4271,31 +4179,9 @@ class FusedAttention(torch.nn.Module):
and cu_seqlens_q is not None and cu_seqlens_q is not None
and cu_seqlens_kv is not None and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if ( if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
seq_offsets_q is None cu_seqlens_q_padded = cu_seqlens_q
or seq_offsets_k is None cu_seqlens_kv_padded = cu_seqlens_kv
or seq_offsets_v is None
or seq_offsets_o is None
or context_parallel
):
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
qkv_group = "hd_hd_hd" if context_parallel else qkv_group
num_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
head_dim = query_layer.shape[-1]
seq_offsets_o = num_heads * head_dim * cu_seqlens_q
if qkv_group == "hd_hd_hd":
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q
seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = num_heads * head_dim * cu_seqlens_q
seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
qkv_dtype = TE_DType[query_layer.dtype] qkv_dtype = TE_DType[query_layer.dtype]
...@@ -4325,10 +4211,8 @@ class FusedAttention(torch.nn.Module): ...@@ -4325,10 +4211,8 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_kv, cu_seqlens_kv,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
self.attention_dropout if self.training else 0.0, self.attention_dropout if self.training else 0.0,
cp_group, cp_group,
cp_global_ranks, cp_global_ranks,
...@@ -4356,10 +4240,8 @@ class FusedAttention(torch.nn.Module): ...@@ -4356,10 +4240,8 @@ class FusedAttention(torch.nn.Module):
max_seqlen_kv, max_seqlen_kv,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
query_layer, query_layer,
key_layer, key_layer,
value_layer, value_layer,
...@@ -4669,10 +4551,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -4669,10 +4551,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: Optional[str] = None, qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
seq_offsets_q: Optional[torch.Tensor] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None,
seq_offsets_k: Optional[torch.Tensor] = None, cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
seq_offsets_v: Optional[torch.Tensor] = None,
seq_offsets_o: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
attn_mask_type: Optional[str] = None, attn_mask_type: Optional[str] = None,
...@@ -4749,23 +4629,21 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -4749,23 +4629,21 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: str, default = `None` qkv_format: str, default = `None`
If provided, overrides :attr:`qkv_format` from initialization. If provided, overrides :attr:`qkv_format` from initialization.
cu_seqlens_q: Optional[torch.Tensor], default = `None` cu_seqlens_q: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `query_layer`, Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32. with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv: Optional[torch.Tensor], default = `None` cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`, Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
with shape [batch_size + 1] and dtype torch.int32. and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
seq_offsets_q: Optional[torch.Tensor], default = `None` cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative offset of different sequences in a batch for `query_layer`, Cumulative sum of sequence lengths (with offset) in a batch for
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. `query_layer`, with shape [batch_size + 1] and dtype torch.int32.
seq_offsets_k: Optional[torch.Tensor], default = `None` When there is no padding between sequences in a batch,
Cumulative offset of different sequences in a batch for `key_layer`, `cu_seqlens_q_padded = cu_seqlens_q`.
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
seq_offsets_v: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
Cumulative offset of different sequences in a batch for `value_layer`, and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. When there is no padding between sequences in a batch,
seq_offsets_o: Optional[torch.Tensor], default = `None` `cu_seqlens_kv_padded = cu_seqlens_kv`.
Cumulative offset of different sequences in a batch for forward output,
with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts.
max_seqlen_q: Optional[int], default = `None` max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`. Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided. Calculated from `cu_seqlens_q` if not provided.
...@@ -4992,9 +4870,25 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -4992,9 +4870,25 @@ class DotProductAttention(TransformerEngineBaseModule):
# certain asserts before executing the forward pass. # certain asserts before executing the forward pass.
# Filter: QKV layout. # Filter: QKV layout.
if use_unfused_attention and qkv_format == "thd": if qkv_format == "thd":
self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd") if use_unfused_attention:
use_unfused_attention = False self.logger.debug("Disabling UnusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if use_fused_attention and (
(
cu_seqlens_q_padded is not None
and torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
)
or (
cu_seqlens_kv_padded is not None
and torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv)
)
):
self.logger.debug(
"Disabling FlashAttention for qkv_format = thd "
"when there is padding between sequences."
)
use_flash_attention = False
# Filter: ONNX export. # Filter: ONNX export.
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -5354,10 +5248,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5354,10 +5248,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k=seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
...@@ -5379,10 +5271,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5379,10 +5271,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q, cu_seqlens_q_padded=cu_seqlens_q_padded,
seq_offsets_k=seq_offsets_k, cu_seqlens_kv_padded=cu_seqlens_kv_padded,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
......
...@@ -85,10 +85,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -85,10 +85,7 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -124,14 +121,8 @@ def fused_attn_fwd_qkvpacked( ...@@ -124,14 +121,8 @@ def fused_attn_fwd_qkvpacked(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
seq_offsets_q: torch.Tensor, default = None cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for QKV; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -246,10 +237,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -246,10 +237,7 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens, cu_seqlens,
qkv, qkv,
qkv_dtype, qkv_dtype,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
q_scale_s, q_scale_s,
...@@ -275,10 +263,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -275,10 +263,7 @@ def fused_attn_bwd_qkvpacked(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None, cu_seqlens_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -322,14 +307,8 @@ def fused_attn_bwd_qkvpacked( ...@@ -322,14 +307,8 @@ def fused_attn_bwd_qkvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for QKV; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -419,10 +398,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -419,10 +398,7 @@ def fused_attn_bwd_qkvpacked(
qkv_dtype, qkv_dtype,
dqkv_dtype, dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
seq_offsets_q, cu_seqlens_padded,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
d_scale_o, d_scale_o,
...@@ -449,10 +425,8 @@ def fused_attn_fwd_kvpacked( ...@@ -449,10 +425,8 @@ def fused_attn_fwd_kvpacked(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -495,14 +469,10 @@ def fused_attn_fwd_kvpacked( ...@@ -495,14 +469,10 @@ def fused_attn_fwd_kvpacked(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
seq_offsets_q: torch.Tensor, default = None cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -621,10 +591,8 @@ def fused_attn_fwd_kvpacked( ...@@ -621,10 +591,8 @@ def fused_attn_fwd_kvpacked(
q, q,
kv, kv,
qkv_dtype, qkv_dtype,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
q_scale_s, q_scale_s,
...@@ -653,10 +621,8 @@ def fused_attn_bwd_kvpacked( ...@@ -653,10 +621,8 @@ def fused_attn_bwd_kvpacked(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -707,14 +673,10 @@ def fused_attn_bwd_kvpacked( ...@@ -707,14 +673,10 @@ def fused_attn_bwd_kvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -811,10 +773,8 @@ def fused_attn_bwd_kvpacked( ...@@ -811,10 +773,8 @@ def fused_attn_bwd_kvpacked(
qkv_dtype, qkv_dtype,
dqkv_dtype, dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
d_scale_o, d_scale_o,
...@@ -842,10 +802,8 @@ def fused_attn_fwd( ...@@ -842,10 +802,8 @@ def fused_attn_fwd(
qkv_dtype: tex.DType, qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None, attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None,
...@@ -892,14 +850,10 @@ def fused_attn_fwd( ...@@ -892,14 +850,10 @@ def fused_attn_fwd(
attn_bias: torch.Tensor, default = None attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
seq_offsets_q: torch.Tensor, default = None cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -1021,10 +975,8 @@ def fused_attn_fwd( ...@@ -1021,10 +975,8 @@ def fused_attn_fwd(
k, k,
v, v,
qkv_dtype, qkv_dtype,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
q_scale_s, q_scale_s,
...@@ -1054,10 +1006,8 @@ def fused_attn_bwd( ...@@ -1054,10 +1006,8 @@ def fused_attn_bwd(
dqkv_dtype: tex.DType, dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor], aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None, d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None, d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None, d_scale_o: torch.Tensor = None,
...@@ -1111,14 +1061,10 @@ def fused_attn_bwd( ...@@ -1111,14 +1061,10 @@ def fused_attn_bwd(
e.g. aux_ctx_tensors = [M, ZInv, rng_state] e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends. please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1] cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1] cumulative sequence offsets for KV; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None d_scale_s: torch.Tensor, default = None
...@@ -1220,10 +1166,8 @@ def fused_attn_bwd( ...@@ -1220,10 +1166,8 @@ def fused_attn_bwd(
qkv_dtype, qkv_dtype,
dqkv_dtype, dqkv_dtype,
aux_ctx_tensors, aux_ctx_tensors,
seq_offsets_q, cu_seqlens_q_padded,
seq_offsets_k, cu_seqlens_kv_padded,
seq_offsets_v,
seq_offsets_o,
d_scale_qkv, d_scale_qkv,
d_scale_s, d_scale_s,
d_scale_o, d_scale_o,
......
...@@ -26,22 +26,19 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -26,22 +26,19 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked( std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
...@@ -53,8 +50,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -53,8 +50,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
...@@ -67,27 +64,27 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -67,27 +64,27 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO, const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dQKV); c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd( std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q, const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias, const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread); size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd( std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
...@@ -95,14 +92,14 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -95,14 +92,14 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dQKV); c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
......
...@@ -83,13 +83,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -83,13 +83,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
using namespace transformer_engine; using namespace transformer_engine;
auto qkv_sizes = QKV.sizes().vec(); auto qkv_sizes = QKV.sizes().vec();
...@@ -107,8 +105,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -107,8 +105,7 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto O = torch::empty(o_shape, options); auto O = torch::empty(o_shape, options);
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens, te_cu_seqlens_padded;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -150,27 +147,12 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -150,27 +147,12 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if (cu_seqlens_padded.has_value()) {
(seq_offsets_o.has_value())) { auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); te_cu_seqlens_padded =
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape,
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
} }
...@@ -191,12 +173,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -191,12 +173,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
TensorWrapper workspace; TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -232,12 +213,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -232,12 +213,11 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
} }
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_cu_seqlens.data(),
te_cu_seqlens.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen,
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -252,9 +232,8 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -252,9 +232,8 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO, const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
...@@ -358,28 +337,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -358,28 +337,13 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper te_cu_seqlens = makeTransformerEngineTensor( TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(
cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; TensorWrapper te_cu_seqlens_padded;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if (cu_seqlens_padded.has_value()) {
(seq_offsets_o.has_value())) { auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); te_cu_seqlens_padded =
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape,
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec();
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()};
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec();
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()};
te_seq_offsets_q =
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
} }
...@@ -387,12 +351,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -387,12 +351,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
TensorWrapper workspace; TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -400,12 +363,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -400,12 +363,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_qkvpacked( nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, &nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_seq_offsets_q.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), max_seqlen,
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(), workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -419,8 +381,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -419,8 +381,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
...@@ -440,7 +402,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -440,7 +402,7 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -489,28 +451,19 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -489,28 +451,19 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
(seq_offsets_o.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_q_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); cu_seqlens_kv_padded_sizes.end()};
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); cu_seqlens_q_padded_shape, DType::kInt32,
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; nullptr, nullptr, nullptr);
te_seq_offsets_q = te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, cu_seqlens_kv_padded_shape, DType::kInt32,
DType::kInt32, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
} }
// extract rng seed and offset // extract rng seed and offset
...@@ -532,10 +485,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -532,10 +485,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -573,10 +526,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -573,10 +526,10 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
// execute the kernel // execute the kernel
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, workspace.data(),
qkv_layout, bias_type, attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -592,14 +545,14 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -592,14 +545,14 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO, const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dQKV) { c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine; using namespace transformer_engine;
auto q_sizes = Q.sizes().vec(); auto q_sizes = Q.sizes().vec();
...@@ -689,29 +642,20 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -689,29 +642,20 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
(seq_offsets_o.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_q_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); cu_seqlens_kv_padded_sizes.end()};
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); cu_seqlens_q_padded_shape, DType::kInt32,
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; nullptr, nullptr, nullptr);
te_seq_offsets_q = te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, cu_seqlens_kv_padded_shape, DType::kInt32,
DType::kInt32, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
} }
// convert auxiliary tensors from forward to NVTETensors // convert auxiliary tensors from forward to NVTETensors
...@@ -746,13 +690,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -746,13 +690,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
TensorWrapper workspace; TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, workspace.data(), at::cuda::getCurrentCUDAStream());
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -760,13 +703,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked( ...@@ -760,13 +703,12 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// execute kernel // execute kernel
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd_kvpacked(
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, workspace.data(), at::cuda::getCurrentCUDAStream());
attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -780,13 +722,13 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -780,13 +722,13 @@ std::vector<at::Tensor> fused_attn_fwd(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q, const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias, const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) { size_t rng_elts_per_thread) {
using namespace transformer_engine; using namespace transformer_engine;
auto q_sizes = Q.sizes().vec(); auto q_sizes = Q.sizes().vec();
...@@ -802,7 +744,7 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -802,7 +744,7 @@ std::vector<at::Tensor> fused_attn_fwd(
// construct NVTE tensors // construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8 // FP8
auto h = q_shape[q_shape.size() - 2]; auto h = q_shape[q_shape.size() - 2];
...@@ -853,28 +795,19 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -853,28 +795,19 @@ std::vector<at::Tensor> fused_attn_fwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
(seq_offsets_o.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_q_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); cu_seqlens_kv_padded_sizes.end()};
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); cu_seqlens_q_padded_shape, DType::kInt32,
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; nullptr, nullptr, nullptr);
te_seq_offsets_q = te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, cu_seqlens_kv_padded_shape, DType::kInt32,
DType::kInt32, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
} }
// extract rng seed and offset // extract rng seed and offset
...@@ -897,11 +830,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -897,11 +830,10 @@ std::vector<at::Tensor> fused_attn_fwd(
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
bias_type, attn_mask_type, workspace.data(), attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors // allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -939,11 +871,10 @@ std::vector<at::Tensor> fused_attn_fwd( ...@@ -939,11 +871,10 @@ std::vector<at::Tensor> fused_attn_fwd(
// execute the kernel // execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_seq_offsets_v.data(), te_seq_offsets_o.data(), te_rng_state.data(), te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
bias_type, attn_mask_type, workspace.data(), attn_mask_type, workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory // destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -959,14 +890,14 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -959,14 +890,14 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO, const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP, const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dQKV) { c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV) {
using namespace transformer_engine; using namespace transformer_engine;
auto q_sizes = Q.sizes().vec(); auto q_sizes = Q.sizes().vec();
...@@ -1131,29 +1062,20 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1131,29 +1062,20 @@ std::vector<at::Tensor> fused_attn_bwd(
te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape,
DType::kInt32, nullptr, nullptr, nullptr); DType::kInt32, nullptr, nullptr, nullptr);
TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value()) && if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) {
(seq_offsets_o.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec();
auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector<size_t> cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(),
std::vector<size_t> seq_offsets_q_shape{seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; cu_seqlens_q_padded_sizes.end()};
auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec();
std::vector<size_t> seq_offsets_k_shape{seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; std::vector<size_t> cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(),
auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); cu_seqlens_kv_padded_sizes.end()};
std::vector<size_t> seq_offsets_v_shape{seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(),
auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); cu_seqlens_q_padded_shape, DType::kInt32,
std::vector<size_t> seq_offsets_o_shape{seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; nullptr, nullptr, nullptr);
te_seq_offsets_q = te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(),
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, cu_seqlens_kv_padded_shape, DType::kInt32,
DType::kInt32, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr);
te_seq_offsets_k =
makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_v =
makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_seq_offsets_o =
makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), seq_offsets_o_shape,
DType::kInt32, nullptr, nullptr, nullptr);
} }
// convert auxiliary tensors from forward to NVTETensors // convert auxiliary tensors from forward to NVTETensors
...@@ -1191,10 +1113,9 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1191,10 +1113,9 @@ std::vector<at::Tensor> fused_attn_bwd(
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
qkv_layout, bias_type, attn_mask_type, workspace.data(), workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// allocate memory for workspace // allocate memory for workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -1205,10 +1126,9 @@ std::vector<at::Tensor> fused_attn_bwd( ...@@ -1205,10 +1126,9 @@ std::vector<at::Tensor> fused_attn_bwd(
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
qkv_layout, bias_type, attn_mask_type, workspace.data(), workspace.data(), at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers // destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment