Unverified Commit 70d3251f authored by Charlene Yang, committed by GitHub

[C/PyTorch] Simplify THD offset tensors (#927)



* simplify offset tensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes; tests pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix C lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_offset with with_padding
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace with_padding with padded
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes after merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for fused attn fwd/bwd calls
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* adjust spacing in docstring
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch tests; fix paddle api
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn_biases
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix AttnFuncWithCP backward
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix jax
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix attn with CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix paddle
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 94a426b0
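In THD ("ragged") layouts, `cu_seqlens` holds the cumulative actual sequence lengths and `cu_seqlens_padded` the cumulative lengths after per-sequence padding; both have shape [batch_size + 1]. Before this change, callers also derived four per-tensor element offsets (`seq_offsets_q/k/v/o`) from the padded lengths and the QKV layout; the diff below replaces those four arguments with `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` and computes the offsets inside the library. A minimal sketch of the two tensors, with made-up sequence lengths and an assumed pad-to-multiple-of-8 policy:

```python
import torch

# Hypothetical THD batch: three sequences, each padded to a multiple of 8.
actual_seqlens = torch.tensor([29, 17, 40], dtype=torch.int32)
padded_seqlens = ((actual_seqlens + 7) // 8) * 8  # tensor([32, 24, 40])

zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens = torch.cat(
    [zero, torch.cumsum(actual_seqlens, 0, dtype=torch.int32)])    # [0, 29, 46, 86]
cu_seqlens_padded = torch.cat(
    [zero, torch.cumsum(padded_seqlens, 0, dtype=torch.int32)])    # [0, 32, 56, 96]
```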
......@@ -832,24 +832,6 @@ def _run_dot_product_attention(
inp[i].requires_grad = True
inp_orig[i].requires_grad = True
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
if qkv_format == "thd":
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
......@@ -928,10 +910,8 @@ def _run_dot_product_attention(
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
seq_offsets_q=seq_offsets_q,
seq_offsets_k=seq_offsets_k,
seq_offsets_v=seq_offsets_v,
seq_offsets_o=seq_offsets_o,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
......@@ -1957,8 +1937,6 @@ class _custom_mha_fp8(torch.autograd.Function):
None,
None,
None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv[META_QKV],
fp8_meta["scaling_fwd"].scale_inv[META_S],
fp8_meta["scaling_fwd"].scale[META_S],
......@@ -2038,8 +2016,6 @@ class _custom_mha_fp8(torch.autograd.Function):
FusedAttnBackend["FP8"],
None,
None,
None,
None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
......
......@@ -8,9 +8,11 @@ import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
_cudnn_version,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
......@@ -58,7 +60,7 @@ model_configs_fused_attn = {
}
@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
......
......@@ -196,21 +196,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
......@@ -252,8 +247,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -272,21 +266,19 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
}
// NVTE fused attention BWD with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
......@@ -338,8 +330,7 @@ void nvte_fused_attn_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -366,21 +357,18 @@ void nvte_fused_attn_bwd_qkvpacked(
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o,
const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
......@@ -426,8 +414,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -450,18 +438,16 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
......@@ -519,9 +505,9 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_O, input_dO, input_Bias, output_S, output_dQ,
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q,
input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace,
stream, handle);
output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -549,9 +535,8 @@ void nvte_fused_attn_bwd_kvpacked(
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......@@ -560,10 +545,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
......@@ -601,8 +584,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_seq_offsets_q, input_seq_offsets_k,
input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -625,20 +608,17 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor seq_offsets_q,
const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor workspace, cudaStream_t stream) {
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
......@@ -690,8 +670,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S,
output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......
......@@ -22,8 +22,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q,
const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
......@@ -31,8 +30,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
......@@ -41,9 +39,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......@@ -52,18 +49,19 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v,
const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream,
cudnnHandle_t handle);
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
void fused_attn_arbitrary_seqlen_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim,
bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
......@@ -73,9 +71,9 @@ void fused_attn_arbitrary_seqlen_bwd(
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O,
const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k,
const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -360,6 +360,38 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
// convert cu_seqlens_padded to offsets
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) {
offsets_o[tid] = h * d * cu_seqlens_q_padded[tid];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = offsets_q[tid];
offsets_v[tid] = offsets_q[tid];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d * cu_seqlens_q_padded[tid];
offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid];
offsets_v[tid] = offsets_k[tid];
break;
}
}
}
} // namespace fused_attn
// get cuDNN data type
......
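The new `cu_seqlens_padded_to_offsets` kernel above is the crux of the simplification: the four offset tensors that frameworks previously built on the host are now derived on device from the padded cumulative lengths. A host-side Python mirror of its per-layout arithmetic (lower-case strings stand in for the `NVTE_QKV_Layout_Group` enum; `h` = attention heads, `hg` = GQA groups, `d` = head dim):

```python
import torch

def cu_seqlens_padded_to_offsets(layout_group: str, h: int, hg: int, d: int,
                                 cu_seqlens_q_padded: torch.Tensor,
                                 cu_seqlens_kv_padded: torch.Tensor):
    """Derive per-tensor element offsets from padded cumulative seqlens."""
    offsets_o = h * d * cu_seqlens_q_padded
    if layout_group == "hd_hd_hd":              # separate Q, K, V tensors
        offsets_q = h * d * cu_seqlens_q_padded
        offsets_k = hg * d * cu_seqlens_kv_padded
        offsets_v = offsets_k
    elif layout_group in ("3hd", "h3d"):        # Q, K, V packed in one tensor
        offsets_q = 3 * h * d * cu_seqlens_q_padded
        offsets_k = offsets_q
        offsets_v = offsets_q
    elif layout_group in ("hd_2hd", "hd_h2d"):  # K and V packed together
        offsets_q = h * d * cu_seqlens_q_padded
        offsets_k = 2 * hg * d * cu_seqlens_kv_padded
        offsets_v = offsets_k
    else:
        raise ValueError(f"unknown layout group: {layout_group}")
    return offsets_q, offsets_k, offsets_v, offsets_o
```

The Jax and Paddle call sites in this diff pass a dummy tensor for the padded cumulative lengths, since the BSHD-family layouts they use (`NVTE_BS3HD`, `NVTE_BSHD_BS2HD`, `NVTE_BSHD_BSHD_BSHD`) do not involve THD ragged offsets.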
......@@ -121,6 +121,11 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d, int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
......@@ -143,19 +143,17 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
TensorWrapper query_workspace_tensor;
if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) {
assert(q_max_seqlen == kv_max_seqlen);
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
......@@ -164,7 +162,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
......@@ -208,15 +205,13 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto qkv_shape = std::vector<size_t>{batch_size * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q_shape = std::vector<size_t>{batch_size * q_max_seqlen, attn_heads, head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -230,7 +225,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, query_workspace_tensor.data(),
nullptr);
......@@ -251,7 +245,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
query_workspace_tensor.data(), nullptr);
......@@ -340,10 +333,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), rng_state_tensor.data(), q_max_seqlen,
descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
rng_state_tensor.data(), q_max_seqlen, descriptor.is_training, descriptor.scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -355,7 +346,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
......@@ -373,7 +363,6 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
descriptor.is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
......@@ -426,7 +415,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, query_workspace_tensor.data(), nullptr);
......@@ -507,15 +495,13 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype);
auto dqkv = buffers[10];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
......@@ -533,7 +519,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) {
......@@ -559,7 +544,6 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
workspace_tensor.data(), stream);
......
......@@ -640,12 +640,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
......@@ -655,12 +654,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_qkvpacked(
te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), QKV.stream());
nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -727,24 +725,22 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd_qkvpacked(
te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack,
te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
QKV.stream());
nvte_fused_attn_bwd_qkvpacked(te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQKV.data(), te_dBias.data(),
te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen,
attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), QKV.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -818,12 +814,12 @@ void te_fused_attn_fwd_kvpacked(
auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast<size_t>(b + 1)}, DType::kInt32);
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -833,12 +829,12 @@ void te_fused_attn_fwd_kvpacked(
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd_kvpacked(
te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack,
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
nvte_fused_attn_fwd_kvpacked(te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -916,7 +912,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -929,7 +924,6 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum,
bias_type_enum, attn_mask_type_enum, workspace.data(), Q.stream());
......@@ -1001,10 +995,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1018,10 +1011,9 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), te_rng_state.data(),
max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale,
p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum,
workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1100,10 +1092,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
......@@ -1113,10 +1104,9 @@ void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(),
te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(),
te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
dummy_seq_offsets.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(),
dummy_seq_offsets.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......
......@@ -85,10 +85,7 @@ def fused_attn_fwd_qkvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -124,14 +121,8 @@ def fused_attn_fwd_qkvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for QKV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -246,10 +237,7 @@ def fused_attn_fwd_qkvpacked(
cu_seqlens,
qkv,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -275,10 +263,7 @@ def fused_attn_bwd_qkvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -322,14 +307,8 @@ def fused_attn_bwd_qkvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_padded: torch.Tensor, default = None
cumulative sequence offsets for QKV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -419,10 +398,7 @@ def fused_attn_bwd_qkvpacked(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......@@ -449,10 +425,8 @@ def fused_attn_fwd_kvpacked(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -495,14 +469,10 @@ def fused_attn_fwd_kvpacked(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -621,10 +591,8 @@ def fused_attn_fwd_kvpacked(
q,
kv,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -653,10 +621,8 @@ def fused_attn_bwd_kvpacked(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -707,14 +673,10 @@ def fused_attn_bwd_kvpacked(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of QKV in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -811,10 +773,8 @@ def fused_attn_bwd_kvpacked(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......@@ -842,10 +802,8 @@ def fused_attn_fwd(
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
attn_bias: torch.Tensor = None,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
q_scale_s: torch.Tensor = None,
......@@ -892,14 +850,10 @@ def fused_attn_fwd(
attn_bias: torch.Tensor, default = None
input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias";
shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -1021,10 +975,8 @@ def fused_attn_fwd(
k,
v,
qkv_dtype,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
q_scale_s,
......@@ -1054,10 +1006,8 @@ def fused_attn_bwd(
dqkv_dtype: tex.DType,
aux_ctx_tensors: List[torch.Tensor],
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
seq_offsets_q: torch.Tensor = None,
seq_offsets_k: torch.Tensor = None,
seq_offsets_v: torch.Tensor = None,
seq_offsets_o: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
d_scale_qkv: torch.Tensor = None,
d_scale_s: torch.Tensor = None,
d_scale_o: torch.Tensor = None,
......@@ -1111,14 +1061,10 @@ def fused_attn_bwd(
e.g. aux_ctx_tensors = [M, ZInv, rng_state]
fused_attention_backend: tex.NVTE_Fused_Attn_Backend
please see FusedAttention module for details on supported backends.
seq_offsets_q: torch.Tensor, default = None
cu_seqlens_q_padded: torch.Tensor, default = None
cumulative sequence offsets for Q; shape [batch_size + 1]
seq_offsets_k: torch.Tensor, default = None
cumulative sequence offsets for K; shape [batch_size + 1]
seq_offsets_v: torch.Tensor, default = None
cumulative sequence offsets for V; shape [batch_size + 1]
seq_offsets_o: torch.Tensor, default = None
cumulative sequence offsets for O; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
d_scale_qkv: torch.Tensor, default = None
input tensor for the dequantization of Q, K and V in FP8 computations
d_scale_s: torch.Tensor, default = None
......@@ -1220,10 +1166,8 @@ def fused_attn_bwd(
qkv_dtype,
dqkv_dtype,
aux_ctx_tensors,
seq_offsets_q,
seq_offsets_k,
seq_offsets_v,
seq_offsets_o,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
d_scale_qkv,
d_scale_s,
d_scale_o,
......
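
The backward entry points take the same cu_seqlens_q_padded / cu_seqlens_kv_padded arguments as the forward, and the auxiliary context tensors are produced under that layout, so a wrapper has to hand the exact same offset tensors to both calls. A schematic of that pattern (not the actual module code; the real fused kernels are replaced by trivial stand-ins):

import torch

class FusedAttnSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, cu_seqlens_q_padded):
        # A real implementation would invoke the fused forward here.
        ctx.save_for_backward(cu_seqlens_q_padded)
        return q.clone()

    @staticmethod
    def backward(ctx, grad_out):
        (cu_seqlens_q_padded,) = ctx.saved_tensors
        # The fused backward must see the offsets the forward used, since
        # aux_ctx_tensors were computed under that padding layout.
        return grad_out, None
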
......@@ -26,22 +26,19 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
......@@ -53,8 +50,8 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type,
const c10::optional<at::Tensor> seq_offsets_q, const c10::optional<at::Tensor> seq_offsets_k,
const c10::optional<at::Tensor> seq_offsets_v, const c10::optional<at::Tensor> seq_offsets_o,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
......@@ -67,27 +64,27 @@ std::vector<at::Tensor> fused_attn_bwd_kvpacked(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor KV, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
std::vector<at::Tensor> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor K, const at::Tensor V,
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_O, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_O,
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......@@ -95,14 +92,14 @@ std::vector<at::Tensor> fused_attn_bwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
const at::Tensor K, const at::Tensor V, const at::Tensor O, const at::Tensor dO,
const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const c10::optional<at::Tensor> seq_offsets_q,
const c10::optional<at::Tensor> seq_offsets_k, const c10::optional<at::Tensor> seq_offsets_v,
const c10::optional<at::Tensor> seq_offsets_o, const c10::optional<at::Tensor> descale_QKV,
const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
c10::optional<at::Tensor> amax_dQKV);
const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
......
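
On the C++ side, every offset or scaling tensor that may be absent is declared c10::optional<at::Tensor>; through the usual pybind11/torch bindings, a Python None arrives as c10::nullopt. A caller with no inter-sequence padding can therefore pass None for the padded offsets, per the default = None shown in the Python signatures above (illustrative only):

cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)
cu_seqlens_padded = None  # becomes c10::nullopt in the extension
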